• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.textclassifier;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static org.hamcrest.CoreMatchers.not;
21 import static org.junit.Assert.assertArrayEquals;
22 import static org.junit.Assert.assertEquals;
23 import static org.junit.Assert.assertThat;
24 import static org.junit.Assert.assertTrue;
25 import static org.mockito.Mockito.any;
26 import static org.mockito.Mockito.eq;
27 import static org.mockito.Mockito.verify;
28 import static org.mockito.Mockito.when;
29 import static org.testng.Assert.expectThrows;
30 
31 import android.app.RemoteAction;
32 import android.content.Context;
33 import android.content.Intent;
34 import android.net.Uri;
35 import android.os.Bundle;
36 import android.os.LocaleList;
37 import android.text.Spannable;
38 import android.text.SpannableString;
39 import android.view.textclassifier.ConversationAction;
40 import android.view.textclassifier.ConversationActions;
41 import android.view.textclassifier.TextClassification;
42 import android.view.textclassifier.TextClassifier;
43 import android.view.textclassifier.TextLanguage;
44 import android.view.textclassifier.TextLinks;
45 import android.view.textclassifier.TextSelection;
46 import androidx.collection.LruCache;
47 import androidx.test.core.app.ApplicationProvider;
48 import androidx.test.ext.junit.runners.AndroidJUnit4;
49 import androidx.test.filters.SdkSuppress;
50 import androidx.test.filters.SmallTest;
51 import com.android.textclassifier.common.ModelFile;
52 import com.android.textclassifier.common.ModelType;
53 import com.android.textclassifier.common.TextClassifierSettings;
54 import com.android.textclassifier.testing.FakeContextBuilder;
55 import com.android.textclassifier.testing.TestingDeviceConfig;
56 import com.google.android.textclassifier.AnnotatorModel;
57 import com.google.common.collect.ImmutableList;
58 import java.io.IOException;
59 import java.util.ArrayList;
60 import java.util.Arrays;
61 import java.util.Collections;
62 import java.util.List;
63 import org.hamcrest.BaseMatcher;
64 import org.hamcrest.Description;
65 import org.hamcrest.Matcher;
66 import org.junit.Before;
67 import org.junit.Test;
68 import org.junit.runner.RunWith;
69 import org.mockito.Mock;
70 import org.mockito.MockitoAnnotations;
71 
72 @SmallTest
73 @RunWith(AndroidJUnit4.class)
74 public class TextClassifierImplTest {
75 
76   private static final String TYPE_COPY = "copy";
77   private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US");
78   private static final String NO_TYPE = null;
79 
80   @Mock private ModelFileManager modelFileManager;
81 
82   private Context context;
83   private TestingDeviceConfig deviceConfig;
84   private TextClassifierSettings settings;
85   private LruCache<ModelFile, AnnotatorModel> annotatorModelCache;
86   private TextClassifierImpl classifier;
87 
88   @Before
setup()89   public void setup() throws IOException {
90     MockitoAnnotations.initMocks(this);
91     this.context =
92         new FakeContextBuilder()
93             .setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT)
94             .setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app")
95             .build();
96     this.deviceConfig = new TestingDeviceConfig();
97     this.settings = new TextClassifierSettings(deviceConfig);
98     this.annotatorModelCache = new LruCache<>(2);
99     this.classifier =
100         new TextClassifierImpl(context, settings, modelFileManager, annotatorModelCache);
101 
102     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
103         .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped());
104     when(modelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any()))
105         .thenReturn(TestDataUtils.getLangIdModelFileWrapped());
106     when(modelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any()))
107         .thenReturn(TestDataUtils.getTestActionsModelFileWrapped());
108   }
109 
110   @Test
testSuggestSelection()111   public void testSuggestSelection() throws IOException {
112     String text = "Contact me at droid@android.com";
113     String selected = "droid";
114     String suggested = "droid@android.com";
115     int startIndex = text.indexOf(selected);
116     int endIndex = startIndex + selected.length();
117     int smartStartIndex = text.indexOf(suggested);
118     int smartEndIndex = smartStartIndex + suggested.length();
119     TextSelection.Request request =
120         new TextSelection.Request.Builder(text, startIndex, endIndex).build();
121 
122     TextSelection selection = classifier.suggestSelection(null, null, request);
123     assertThat(
124         selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_EMAIL));
125   }
126 
127   @Test
testSuggestSelection_localePreferenceIsPassedToModelFileManager()128   public void testSuggestSelection_localePreferenceIsPassedToModelFileManager() throws IOException {
129     String text = "Contact me at droid@android.com";
130     String selected = "droid";
131     String suggested = "droid@android.com";
132     int startIndex = text.indexOf(selected);
133     int endIndex = startIndex + selected.length();
134     int smartStartIndex = text.indexOf(suggested);
135     int smartEndIndex = smartStartIndex + suggested.length();
136     TextSelection.Request request =
137         new TextSelection.Request.Builder(text, startIndex, endIndex)
138             .setDefaultLocales(LOCALES)
139             .build();
140 
141     classifier.suggestSelection(null, null, request);
142     verify(modelFileManager).findBestModelFile(eq(ModelType.ANNOTATOR), eq(LOCALES), any());
143   }
144 
145   @Test
testSuggestSelection_url()146   public void testSuggestSelection_url() throws IOException {
147     String text = "Visit http://www.android.com for more information";
148     String selected = "http";
149     String suggested = "http://www.android.com";
150     int startIndex = text.indexOf(selected);
151     int endIndex = startIndex + selected.length();
152     int smartStartIndex = text.indexOf(suggested);
153     int smartEndIndex = smartStartIndex + suggested.length();
154     TextSelection.Request request =
155         new TextSelection.Request.Builder(text, startIndex, endIndex).build();
156 
157     TextSelection selection = classifier.suggestSelection(null, null, request);
158     assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL));
159   }
160 
161   @Test
testSmartSelection_withEmoji()162   public void testSmartSelection_withEmoji() throws IOException {
163     String text = "\uD83D\uDE02 Hello.";
164     String selected = "Hello";
165     int startIndex = text.indexOf(selected);
166     int endIndex = startIndex + selected.length();
167     TextSelection.Request request =
168         new TextSelection.Request.Builder(text, startIndex, endIndex).build();
169 
170     TextSelection selection = classifier.suggestSelection(null, null, request);
171     assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE));
172   }
173 
174   @SdkSuppress(minSdkVersion = 31, codeName = "S")
175   @Test
testSuggestSelection_includeTextClassification()176   public void testSuggestSelection_includeTextClassification() throws IOException {
177     String text = "Visit http://www.android.com for more information";
178     String suggested = "http://www.android.com";
179     int startIndex = text.indexOf(suggested);
180     TextSelection.Request request =
181         new TextSelection.Request.Builder(text, startIndex, /*endIndex=*/ startIndex + 1)
182             .setIncludeTextClassification(true)
183             .build();
184 
185     TextSelection selection = classifier.suggestSelection(null, null, request);
186 
187     assertThat(
188         selection.getTextClassification(),
189         isTextClassification(suggested, TextClassifier.TYPE_URL));
190     assertThat(selection.getTextClassification(), containsIntentWithAction(Intent.ACTION_VIEW));
191   }
192 
193   @SdkSuppress(minSdkVersion = 31, codeName = "S")
194   @Test
testSuggestSelection_notIncludeTextClassification()195   public void testSuggestSelection_notIncludeTextClassification() throws IOException {
196     String text = "Visit http://www.android.com for more information";
197     TextSelection.Request request =
198         new TextSelection.Request.Builder(text, /*startIndex=*/ 0, /*endIndex=*/ 4)
199             .setIncludeTextClassification(false)
200             .build();
201 
202     TextSelection selection = classifier.suggestSelection(null, null, request);
203 
204     assertThat(selection.getTextClassification()).isNull();
205   }
206 
207   @Test
testClassifyText()208   public void testClassifyText() throws IOException {
209     String text = "Contact me at droid@android.com";
210     String classifiedText = "droid@android.com";
211     int startIndex = text.indexOf(classifiedText);
212     int endIndex = startIndex + classifiedText.length();
213     TextClassification.Request request =
214         new TextClassification.Request.Builder(text, startIndex, endIndex).build();
215 
216     TextClassification classification =
217         classifier.classifyText(/* sessionId= */ null, null, request);
218     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_EMAIL));
219   }
220 
221   @Test
testClassifyText_url()222   public void testClassifyText_url() throws IOException {
223     String text = "Visit www.android.com for more information";
224     String classifiedText = "www.android.com";
225     int startIndex = text.indexOf(classifiedText);
226     int endIndex = startIndex + classifiedText.length();
227     TextClassification.Request request =
228         new TextClassification.Request.Builder(text, startIndex, endIndex).build();
229 
230     TextClassification classification = classifier.classifyText(null, null, request);
231     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
232     assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
233   }
234 
235   @Test
testClassifyText_address()236   public void testClassifyText_address() throws IOException {
237     String text = "Brandschenkestrasse 110, Zürich, Switzerland";
238     TextClassification.Request request =
239         new TextClassification.Request.Builder(text, 0, text.length()).build();
240 
241     TextClassification classification = classifier.classifyText(null, null, request);
242     assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS));
243   }
244 
245   @Test
testClassifyText_url_inCaps()246   public void testClassifyText_url_inCaps() throws IOException {
247     String text = "Visit HTTP://ANDROID.COM for more information";
248     String classifiedText = "HTTP://ANDROID.COM";
249     int startIndex = text.indexOf(classifiedText);
250     int endIndex = startIndex + classifiedText.length();
251     TextClassification.Request request =
252         new TextClassification.Request.Builder(text, startIndex, endIndex).build();
253 
254     TextClassification classification = classifier.classifyText(null, null, request);
255     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
256     assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
257   }
258 
259   @Test
testClassifyText_date()260   public void testClassifyText_date() throws IOException {
261     String text = "Let's meet on January 9, 2018.";
262     String classifiedText = "January 9, 2018";
263     int startIndex = text.indexOf(classifiedText);
264     int endIndex = startIndex + classifiedText.length();
265     TextClassification.Request request =
266         new TextClassification.Request.Builder(text, startIndex, endIndex).build();
267 
268     TextClassification classification = classifier.classifyText(null, null, request);
269     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
270     Bundle extras = classification.getExtras();
271     List<Bundle> entities = ExtrasUtils.getEntities(extras);
272     assertThat(entities).hasSize(1);
273     assertThat(ExtrasUtils.getEntityType(entities.get(0))).isEqualTo(TextClassifier.TYPE_DATE);
274     ArrayList<Intent> actionsIntents = ExtrasUtils.getActionsIntents(classification);
275     actionsIntents.forEach(TextClassifierImplTest::assertNoPackageInfoInExtras);
276   }
277 
278   @Test
testClassifyText_datetime()279   public void testClassifyText_datetime() throws IOException {
280     String text = "Let's meet 2018/01/01 10:30:20.";
281     String classifiedText = "2018/01/01 10:30:20";
282     int startIndex = text.indexOf(classifiedText);
283     int endIndex = startIndex + classifiedText.length();
284     TextClassification.Request request =
285         new TextClassification.Request.Builder(text, startIndex, endIndex).build();
286 
287     TextClassification classification = classifier.classifyText(null, null, request);
288     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME));
289   }
290 
291   @Test
testClassifyText_foreignText()292   public void testClassifyText_foreignText() throws IOException {
293     LocaleList originalLocales = LocaleList.getDefault();
294     LocaleList.setDefault(LocaleList.forLanguageTags("en"));
295     String japaneseText = "これは日本語のテキストです";
296     TextClassification.Request request =
297         new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()).build();
298 
299     TextClassification classification = classifier.classifyText(null, null, request);
300     RemoteAction translateAction = classification.getActions().get(0);
301     assertEquals(1, classification.getActions().size());
302     assertEquals(Intent.ACTION_TRANSLATE, classification.getIntent().getAction());
303 
304     assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification));
305     Intent intent = ExtrasUtils.getActionsIntents(classification).get(0);
306     assertNoPackageInfoInExtras(intent);
307     assertEquals(Intent.ACTION_TRANSLATE, intent.getAction());
308     Bundle foreignLanguageInfo = ExtrasUtils.getForeignLanguageExtra(classification);
309     assertEquals("ja", ExtrasUtils.getEntityType(foreignLanguageInfo));
310     assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) >= 0);
311     assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) <= 1);
312     assertTrue(intent.hasExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER));
313     assertEquals("ja", ExtrasUtils.getTopLanguage(intent).first);
314 
315     LocaleList.setDefault(originalLocales);
316   }
317 
318   @Test
testGenerateLinks_phone()319   public void testGenerateLinks_phone() throws IOException {
320     String text = "The number is +12122537077. See you tonight!";
321     TextLinks.Request request = new TextLinks.Request.Builder(text).build();
322     assertThat(
323         classifier.generateLinks(null, null, request),
324         isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE));
325   }
326 
327   @Test
testGenerateLinks_exclude()328   public void testGenerateLinks_exclude() throws IOException {
329     String text = "The number is +12122537077. See you tonight!";
330     List<String> hints = ImmutableList.of();
331     List<String> included = ImmutableList.of();
332     List<String> excluded = Arrays.asList(TextClassifier.TYPE_PHONE);
333     TextLinks.Request request =
334         new TextLinks.Request.Builder(text)
335             .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
336             .build();
337     assertThat(
338         classifier.generateLinks(null, null, request),
339         not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE)));
340   }
341 
342   @Test
testGenerateLinks_explicit_address()343   public void testGenerateLinks_explicit_address() throws IOException {
344     String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!";
345     List<String> explicit = Arrays.asList(TextClassifier.TYPE_ADDRESS);
346     TextLinks.Request request =
347         new TextLinks.Request.Builder(text)
348             .setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit))
349             .build();
350     assertThat(
351         classifier.generateLinks(null, null, request),
352         isTextLinksContaining(
353             text, "1600 Amphitheater Parkway, Mountain View, CA", TextClassifier.TYPE_ADDRESS));
354   }
355 
356   @Test
testGenerateLinks_exclude_override()357   public void testGenerateLinks_exclude_override() throws IOException {
358     String text = "You want apple@banana.com. See you tonight!";
359     List<String> hints = ImmutableList.of();
360     List<String> included = Arrays.asList(TextClassifier.TYPE_EMAIL);
361     List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL);
362     TextLinks.Request request =
363         new TextLinks.Request.Builder(text)
364             .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
365             .build();
366     assertThat(
367         classifier.generateLinks(null, null, request),
368         not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
369   }
370 
371   @Test
testGenerateLinks_maxLength()372   public void testGenerateLinks_maxLength() throws IOException {
373     char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength()];
374     Arrays.fill(manySpaces, ' ');
375     TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
376     TextLinks links = classifier.generateLinks(null, null, request);
377     assertTrue(links.getLinks().isEmpty());
378   }
379 
380   @Test
testApplyLinks_unsupportedCharacter()381   public void testApplyLinks_unsupportedCharacter() throws IOException {
382     Spannable url = new SpannableString("\u202Emoc.diordna.com");
383     TextLinks.Request request = new TextLinks.Request.Builder(url).build();
384     assertEquals(
385         TextLinks.STATUS_UNSUPPORTED_CHARACTER,
386         classifier.generateLinks(null, null, request).apply(url, 0, null));
387   }
388 
389   @Test
testGenerateLinks_tooLong()390   public void testGenerateLinks_tooLong() {
391     char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength() + 1];
392     Arrays.fill(manySpaces, ' ');
393     TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
394     expectThrows(
395         IllegalArgumentException.class, () -> classifier.generateLinks(null, null, request));
396   }
397 
398   @Test
testGenerateLinks_entityData()399   public void testGenerateLinks_entityData() throws IOException {
400     String text = "The number is +12122537077.";
401     Bundle extras = new Bundle();
402     ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true);
403     TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build();
404 
405     TextLinks textLinks = classifier.generateLinks(null, null, request);
406 
407     assertThat(textLinks.getLinks()).hasSize(1);
408     TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
409     List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
410     assertThat(entities).hasSize(1);
411     Bundle entity = entities.get(0);
412     assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_PHONE);
413   }
414 
415   @Test
testGenerateLinks_entityData_disabled()416   public void testGenerateLinks_entityData_disabled() throws IOException {
417     String text = "The number is +12122537077.";
418     TextLinks.Request request = new TextLinks.Request.Builder(text).build();
419 
420     TextLinks textLinks = classifier.generateLinks(null, null, request);
421 
422     assertThat(textLinks.getLinks()).hasSize(1);
423     TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
424     List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
425     assertThat(entities).isNull();
426   }
427 
428   @Test
testDetectLanguage()429   public void testDetectLanguage() throws IOException {
430     String text = "This is English text";
431     TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
432     TextLanguage textLanguage = classifier.detectLanguage(null, null, request);
433     assertThat(textLanguage, isTextLanguage("en"));
434   }
435 
436   @Test
testDetectLanguage_japanese()437   public void testDetectLanguage_japanese() throws IOException {
438     String text = "これは日本語のテキストです";
439     TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
440     TextLanguage textLanguage = classifier.detectLanguage(null, null, request);
441     assertThat(textLanguage, isTextLanguage("ja"));
442   }
443 
444   @Test
testSuggestConversationActions_textReplyOnly_maxOne()445   public void testSuggestConversationActions_textReplyOnly_maxOne() throws IOException {
446     ConversationActions.Message message =
447         new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
448             .setText("Where are you?")
449             .build();
450     TextClassifier.EntityConfig typeConfig =
451         new TextClassifier.EntityConfig.Builder()
452             .includeTypesFromTextClassifier(false)
453             .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY))
454             .build();
455     ConversationActions.Request request =
456         new ConversationActions.Request.Builder(Collections.singletonList(message))
457             .setMaxSuggestions(1)
458             .setTypeConfig(typeConfig)
459             .build();
460 
461     ConversationActions conversationActions =
462         classifier.suggestConversationActions(null, null, request);
463     assertThat(conversationActions.getConversationActions()).hasSize(1);
464     ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
465     assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_TEXT_REPLY);
466     assertThat(conversationAction.getTextReply()).isNotNull();
467   }
468 
469   @Test
testSuggestConversationActions_textReplyOnly_noMax()470   public void testSuggestConversationActions_textReplyOnly_noMax() throws IOException {
471     ConversationActions.Message message =
472         new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
473             .setText("Where are you?")
474             .build();
475     TextClassifier.EntityConfig typeConfig =
476         new TextClassifier.EntityConfig.Builder()
477             .includeTypesFromTextClassifier(false)
478             .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY))
479             .build();
480     ConversationActions.Request request =
481         new ConversationActions.Request.Builder(Collections.singletonList(message))
482             .setTypeConfig(typeConfig)
483             .build();
484 
485     ConversationActions conversationActions =
486         classifier.suggestConversationActions(null, null, request);
487     assertTrue(conversationActions.getConversationActions().size() > 1);
488     for (ConversationAction conversationAction : conversationActions.getConversationActions()) {
489       assertThat(conversationAction, isConversationAction(ConversationAction.TYPE_TEXT_REPLY));
490     }
491   }
492 
493   @Test
testSuggestConversationActions_openUrl()494   public void testSuggestConversationActions_openUrl() throws IOException {
495     ConversationActions.Message message =
496         new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
497             .setText("Check this out: https://www.android.com")
498             .build();
499     TextClassifier.EntityConfig typeConfig =
500         new TextClassifier.EntityConfig.Builder()
501             .includeTypesFromTextClassifier(false)
502             .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_OPEN_URL))
503             .build();
504     ConversationActions.Request request =
505         new ConversationActions.Request.Builder(Collections.singletonList(message))
506             .setMaxSuggestions(1)
507             .setTypeConfig(typeConfig)
508             .build();
509 
510     ConversationActions conversationActions =
511         classifier.suggestConversationActions(null, null, request);
512     assertThat(conversationActions.getConversationActions()).hasSize(1);
513     ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
514     assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
515     Intent actionIntent = ExtrasUtils.getActionIntent(conversationAction.getExtras());
516     assertThat(actionIntent.getAction()).isEqualTo(Intent.ACTION_VIEW);
517     assertThat(actionIntent.getData()).isEqualTo(Uri.parse("https://www.android.com"));
518     assertNoPackageInfoInExtras(actionIntent);
519   }
520 
521   @Test
testSuggestConversationActions_copy()522   public void testSuggestConversationActions_copy() throws IOException {
523     ConversationActions.Message message =
524         new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
525             .setText("Authentication code: 12345")
526             .build();
527     TextClassifier.EntityConfig typeConfig =
528         new TextClassifier.EntityConfig.Builder()
529             .includeTypesFromTextClassifier(false)
530             .setIncludedTypes(Collections.singletonList(TYPE_COPY))
531             .build();
532     ConversationActions.Request request =
533         new ConversationActions.Request.Builder(Collections.singletonList(message))
534             .setMaxSuggestions(1)
535             .setTypeConfig(typeConfig)
536             .build();
537 
538     ConversationActions conversationActions =
539         classifier.suggestConversationActions(null, null, request);
540     assertThat(conversationActions.getConversationActions()).hasSize(1);
541     ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
542     assertThat(conversationAction.getType()).isEqualTo(TYPE_COPY);
543     assertThat(conversationAction.getTextReply()).isAnyOf(null, "");
544     assertThat(conversationAction.getAction()).isNull();
545     String code = ExtrasUtils.getCopyText(conversationAction.getExtras());
546     assertThat(code).isEqualTo("12345");
547     assertThat(ExtrasUtils.getSerializedEntityData(conversationAction.getExtras())).isNotEmpty();
548   }
549 
550   @Test
testSuggestConversationActions_deduplicate()551   public void testSuggestConversationActions_deduplicate() throws IOException {
552     ConversationActions.Message message =
553         new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
554             .setText("a@android.com b@android.com")
555             .build();
556     ConversationActions.Request request =
557         new ConversationActions.Request.Builder(Collections.singletonList(message))
558             .setMaxSuggestions(3)
559             .build();
560 
561     ConversationActions conversationActions =
562         classifier.suggestConversationActions(null, null, request);
563 
564     assertThat(conversationActions.getConversationActions()).isEmpty();
565   }
566 
567   @Test
testUseCachedAnnotatorModelDisabled()568   public void testUseCachedAnnotatorModelDisabled() throws IOException {
569     deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
570 
571     String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath();
572     ModelFile annotatorModelA =
573         new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false);
574     ModelFile annotatorModelB =
575         new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false);
576 
577     String englishText = "You can reach me on +12122537077.";
578     String classifiedText = "+12122537077";
579     TextClassification.Request request =
580         new TextClassification.Request.Builder(englishText, 0, englishText.length()).build();
581 
582     // Check modelFileA v701
583     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
584         .thenReturn(annotatorModelA);
585     TextClassification classificationA = classifier.classifyText(null, null, request);
586 
587     assertThat(classificationA.getId()).contains("v701");
588     assertThat(classificationA.getText()).contains(classifiedText);
589     assertArrayEquals(
590         new int[] {0, 0, 0, 0},
591         new int[] {
592           annotatorModelCache.putCount(),
593           annotatorModelCache.evictionCount(),
594           annotatorModelCache.hitCount(),
595           annotatorModelCache.missCount()
596         });
597 
598     // Check modelFileB v801
599     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
600         .thenReturn(annotatorModelB);
601     TextClassification classificationB = classifier.classifyText(null, null, request);
602 
603     assertThat(classificationB.getId()).contains("v801");
604     assertThat(classificationB.getText()).contains(classifiedText);
605     assertArrayEquals(
606         new int[] {0, 0, 0, 0},
607         new int[] {
608           annotatorModelCache.putCount(),
609           annotatorModelCache.evictionCount(),
610           annotatorModelCache.hitCount(),
611           annotatorModelCache.missCount()
612         });
613 
614     // Reload modelFileA v701
615     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
616         .thenReturn(annotatorModelA);
617     TextClassification classificationAcached = classifier.classifyText(null, null, request);
618 
619     assertThat(classificationAcached.getId()).contains("v701");
620     assertThat(classificationAcached.getText()).contains(classifiedText);
621     assertArrayEquals(
622         new int[] {0, 0, 0, 0},
623         new int[] {
624           annotatorModelCache.putCount(),
625           annotatorModelCache.evictionCount(),
626           annotatorModelCache.hitCount(),
627           annotatorModelCache.missCount()
628         });
629   }
630 
631   @Test
testUseCachedAnnotatorModelEnabled()632   public void testUseCachedAnnotatorModelEnabled() throws IOException {
633     deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
634     deviceConfig.setConfig(TextClassifierSettings.MULTI_ANNOTATOR_CACHE_ENABLED, true);
635 
636     String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath();
637     ModelFile annotatorModelA =
638         new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false);
639     ModelFile annotatorModelB =
640         new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false);
641 
642     String englishText = "You can reach me on +12122537077.";
643     String classifiedText = "+12122537077";
644     TextClassification.Request request =
645         new TextClassification.Request.Builder(englishText, 0, englishText.length()).build();
646 
647     // Check modelFileA v701
648     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
649         .thenReturn(annotatorModelA);
650     TextClassification classification = classifier.classifyText(null, null, request);
651 
652     assertThat(classification.getId()).contains("v701");
653     assertThat(classification.getText()).contains(classifiedText);
654     assertArrayEquals(
655         new int[] {1, 0, 0, 1},
656         new int[] {
657           annotatorModelCache.putCount(),
658           annotatorModelCache.evictionCount(),
659           annotatorModelCache.hitCount(),
660           annotatorModelCache.missCount()
661         });
662 
663     // Check modelFileB v801
664     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
665         .thenReturn(annotatorModelB);
666     TextClassification classificationB = classifier.classifyText(null, null, request);
667 
668     assertThat(classificationB.getId()).contains("v801");
669     assertThat(classificationB.getText()).contains(classifiedText);
670     assertArrayEquals(
671         new int[] {2, 0, 0, 2},
672         new int[] {
673           annotatorModelCache.putCount(),
674           annotatorModelCache.evictionCount(),
675           annotatorModelCache.hitCount(),
676           annotatorModelCache.missCount()
677         });
678 
679     // Reload modelFileA v701
680     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
681         .thenReturn(annotatorModelA);
682     TextClassification classificationAcached = classifier.classifyText(null, null, request);
683 
684     assertThat(classificationAcached.getId()).contains("v701");
685     assertThat(classificationAcached.getText()).contains(classifiedText);
686     assertArrayEquals(
687         new int[] {2, 0, 1, 2},
688         new int[] {
689           annotatorModelCache.putCount(),
690           annotatorModelCache.evictionCount(),
691           annotatorModelCache.hitCount(),
692           annotatorModelCache.missCount()
693         });
694   }
695 
assertNoPackageInfoInExtras(Intent intent)696   private static void assertNoPackageInfoInExtras(Intent intent) {
697     assertThat(intent.getComponent()).isNull();
698     assertThat(intent.getPackage()).isNull();
699   }
700 
isTextSelection( final int startIndex, final int endIndex, final String type)701   private static Matcher<TextSelection> isTextSelection(
702       final int startIndex, final int endIndex, final String type) {
703     return new BaseMatcher<TextSelection>() {
704       @Override
705       public boolean matches(Object o) {
706         if (o instanceof TextSelection) {
707           TextSelection selection = (TextSelection) o;
708           return startIndex == selection.getSelectionStartIndex()
709               && endIndex == selection.getSelectionEndIndex()
710               && typeMatches(selection, type);
711         }
712         return false;
713       }
714 
715       private boolean typeMatches(TextSelection selection, String type) {
716         return type == null
717             || (selection.getEntityCount() > 0
718                 && type.trim().equalsIgnoreCase(selection.getEntity(0)));
719       }
720 
721       @Override
722       public void describeTo(Description description) {
723         description.appendValue(String.format("%d, %d, %s", startIndex, endIndex, type));
724       }
725     };
726   }
727 
728   private static Matcher<TextLinks> isTextLinksContaining(
729       final String text, final String substring, final String type) {
730     return new BaseMatcher<TextLinks>() {
731 
732       @Override
733       public void describeTo(Description description) {
734         description
735             .appendText("text=")
736             .appendValue(text)
737             .appendText(", substring=")
738             .appendValue(substring)
739             .appendText(", type=")
740             .appendValue(type);
741       }
742 
743       @Override
744       public boolean matches(Object o) {
745         if (o instanceof TextLinks) {
746           for (TextLinks.TextLink link : ((TextLinks) o).getLinks()) {
747             if (text.subSequence(link.getStart(), link.getEnd()).toString().equals(substring)) {
748               return type.equals(link.getEntity(0));
749             }
750           }
751         }
752         return false;
753       }
754     };
755   }
756 
757   private static Matcher<TextClassification> isTextClassification(
758       final String text, final String type) {
759     return new BaseMatcher<TextClassification>() {
760       @Override
761       public boolean matches(Object o) {
762         if (o instanceof TextClassification) {
763           TextClassification result = (TextClassification) o;
764           return text.equals(result.getText())
765               && result.getEntityCount() > 0
766               && type.equals(result.getEntity(0));
767         }
768         return false;
769       }
770 
771       @Override
772       public void describeTo(Description description) {
773         description.appendText("text=").appendValue(text).appendText(", type=").appendValue(type);
774       }
775     };
776   }
777 
778   private static Matcher<TextClassification> containsIntentWithAction(final String action) {
779     return new BaseMatcher<TextClassification>() {
780       @Override
781       public boolean matches(Object o) {
782         if (o instanceof TextClassification) {
783           TextClassification result = (TextClassification) o;
784           return ExtrasUtils.findAction(result, action) != null;
785         }
786         return false;
787       }
788 
789       @Override
790       public void describeTo(Description description) {
791         description.appendText("intent action=").appendValue(action);
792       }
793     };
794   }
795 
796   private static Matcher<TextLanguage> isTextLanguage(final String languageTag) {
797     return new BaseMatcher<TextLanguage>() {
798       @Override
799       public boolean matches(Object o) {
800         if (o instanceof TextLanguage) {
801           TextLanguage result = (TextLanguage) o;
802           return result.getLocaleHypothesisCount() > 0
803               && languageTag.equals(result.getLocale(0).toLanguageTag());
804         }
805         return false;
806       }
807 
808       @Override
809       public void describeTo(Description description) {
810         description.appendText("locale=").appendValue(languageTag);
811       }
812     };
813   }
814 
815   private static Matcher<ConversationAction> isConversationAction(String actionType) {
816     return new BaseMatcher<ConversationAction>() {
817       @Override
818       public boolean matches(Object o) {
819         if (!(o instanceof ConversationAction)) {
820           return false;
821         }
822         ConversationAction conversationAction = (ConversationAction) o;
823         if (!actionType.equals(conversationAction.getType())) {
824           return false;
825         }
826         if (ConversationAction.TYPE_TEXT_REPLY.equals(actionType)) {
827           if (conversationAction.getTextReply() == null) {
828             return false;
829           }
830         }
831         if (conversationAction.getConfidenceScore() < 0
832             || conversationAction.getConfidenceScore() > 1) {
833           return false;
834         }
835         return true;
836       }
837 
838       @Override
839       public void describeTo(Description description) {
840         description.appendText("actionType=").appendValue(actionType);
841       }
842     };
843   }
844 }
845