• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.textclassifier;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static org.hamcrest.CoreMatchers.not;
21 import static org.junit.Assert.assertArrayEquals;
22 import static org.junit.Assert.assertEquals;
23 import static org.junit.Assert.assertThat;
24 import static org.junit.Assert.assertTrue;
25 import static org.mockito.Mockito.any;
26 import static org.mockito.Mockito.eq;
27 import static org.mockito.Mockito.verify;
28 import static org.mockito.Mockito.when;
29 import static org.testng.Assert.expectThrows;
30 
31 import android.app.RemoteAction;
32 import android.content.Context;
33 import android.content.Intent;
34 import android.net.Uri;
35 import android.os.Bundle;
36 import android.os.LocaleList;
37 import android.permission.flags.Flags;
38 import android.platform.test.annotations.RequiresFlagsEnabled;
39 import android.text.Spannable;
40 import android.text.SpannableString;
41 import android.view.textclassifier.ConversationAction;
42 import android.view.textclassifier.ConversationActions;
43 import android.view.textclassifier.TextClassification;
44 import android.view.textclassifier.TextClassifier;
45 import android.view.textclassifier.TextLanguage;
46 import android.view.textclassifier.TextLinks;
47 import android.view.textclassifier.TextSelection;
48 import androidx.collection.LruCache;
49 import androidx.test.ext.junit.runners.AndroidJUnit4;
50 import androidx.test.filters.SdkSuppress;
51 import androidx.test.filters.SmallTest;
52 import com.android.textclassifier.common.ModelFile;
53 import com.android.textclassifier.common.ModelType;
54 import com.android.textclassifier.common.TextClassifierSettings;
55 import com.android.textclassifier.testing.FakeContextBuilder;
56 import com.android.textclassifier.testing.TestingDeviceConfig;
57 import com.android.textclassifier.utils.TextClassifierUtils;
58 
59 import com.google.android.textclassifier.AnnotatorModel;
60 import com.google.common.collect.ImmutableList;
61 import java.io.IOException;
62 import java.util.ArrayList;
63 import java.util.Arrays;
64 import java.util.Collections;
65 import java.util.List;
66 import org.hamcrest.BaseMatcher;
67 import org.hamcrest.Description;
68 import org.hamcrest.Matcher;
69 import org.junit.Assume;
70 import org.junit.Before;
71 import org.junit.Test;
72 import org.junit.runner.RunWith;
73 import org.mockito.Mock;
74 import org.mockito.MockitoAnnotations;
75 
76 @SmallTest
77 @RunWith(AndroidJUnit4.class)
78 public class TextClassifierImplTest {
79 
80   private static final String TYPE_COPY = "copy";
81   private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US");
82   private static final String NO_TYPE = null;
83 
84   @Mock private ModelFileManager modelFileManager;
85 
86   private Context context;
87   private TestingDeviceConfig deviceConfig;
88   private TextClassifierSettings settings;
89   private LruCache<ModelFile, AnnotatorModel> annotatorModelCache;
90   private TextClassifierImpl classifier;
91 
92   @Before
setup()93   public void setup() throws IOException {
94     MockitoAnnotations.initMocks(this);
95     this.context =
96         new FakeContextBuilder()
97             .setAllIntentComponent(FakeContextBuilder.DEFAULT_COMPONENT)
98             .setAppLabel(FakeContextBuilder.DEFAULT_COMPONENT.getPackageName(), "Test app")
99             .build();
100     this.deviceConfig = new TestingDeviceConfig();
101     this.settings = new TextClassifierSettings(deviceConfig, /* isWear= */ false);
102     this.annotatorModelCache = new LruCache<>(2);
103     this.classifier =
104         new TextClassifierImpl(context, settings, modelFileManager, annotatorModelCache);
105 
106     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
107         .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped());
108     when(modelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any()))
109         .thenReturn(TestDataUtils.getLangIdModelFileWrapped());
110     when(modelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any()))
111         .thenReturn(TestDataUtils.getTestActionsModelFileWrapped());
112   }
113 
114   @Test
testSuggestSelection()115   public void testSuggestSelection() throws IOException {
116     String text = "Contact me at droid@android.com";
117     String selected = "droid";
118     String suggested = "droid@android.com";
119     int startIndex = text.indexOf(selected);
120     int endIndex = startIndex + selected.length();
121     int smartStartIndex = text.indexOf(suggested);
122     int smartEndIndex = smartStartIndex + suggested.length();
123     TextSelection.Request request =
124         new TextSelection.Request.Builder(text, startIndex, endIndex).build();
125 
126     TextSelection selection = classifier.suggestSelection(null, null, request);
127     assertThat(
128         selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_EMAIL));
129   }
130 
131   @Test
testSuggestSelection_localePreferenceIsPassedToModelFileManager()132   public void testSuggestSelection_localePreferenceIsPassedToModelFileManager() throws IOException {
133     String text = "Contact me at droid@android.com";
134     String selected = "droid";
135     String suggested = "droid@android.com";
136     int startIndex = text.indexOf(selected);
137     int endIndex = startIndex + selected.length();
138     int smartStartIndex = text.indexOf(suggested);
139     int smartEndIndex = smartStartIndex + suggested.length();
140     TextSelection.Request request =
141         new TextSelection.Request.Builder(text, startIndex, endIndex)
142             .setDefaultLocales(LOCALES)
143             .build();
144 
145     classifier.suggestSelection(null, null, request);
146     verify(modelFileManager).findBestModelFile(eq(ModelType.ANNOTATOR), eq(LOCALES), any());
147   }
148 
149   @Test
testSuggestSelection_url()150   public void testSuggestSelection_url() throws IOException {
151     String text = "Visit http://www.android.com for more information";
152     String selected = "http";
153     String suggested = "http://www.android.com";
154     int startIndex = text.indexOf(selected);
155     int endIndex = startIndex + selected.length();
156     int smartStartIndex = text.indexOf(suggested);
157     int smartEndIndex = smartStartIndex + suggested.length();
158     TextSelection.Request request =
159         new TextSelection.Request.Builder(text, startIndex, endIndex).build();
160 
161     TextSelection selection = classifier.suggestSelection(null, null, request);
162     assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL));
163   }
164 
165   @Test
testSmartSelection_withEmoji()166   public void testSmartSelection_withEmoji() throws IOException {
167     String text = "\uD83D\uDE02 Hello.";
168     String selected = "Hello";
169     int startIndex = text.indexOf(selected);
170     int endIndex = startIndex + selected.length();
171     TextSelection.Request request =
172         new TextSelection.Request.Builder(text, startIndex, endIndex).build();
173 
174     TextSelection selection = classifier.suggestSelection(null, null, request);
175     assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE));
176   }
177 
178   @SdkSuppress(minSdkVersion = 31, codeName = "S")
179   @Test
testSuggestSelection_includeTextClassification()180   public void testSuggestSelection_includeTextClassification() throws IOException {
181     String text = "Visit http://www.android.com for more information";
182     String suggested = "http://www.android.com";
183     int startIndex = text.indexOf(suggested);
184     TextSelection.Request request =
185         new TextSelection.Request.Builder(text, startIndex, /* endIndex= */ startIndex + 1)
186             .setIncludeTextClassification(true)
187             .build();
188 
189     TextSelection selection = classifier.suggestSelection(null, null, request);
190 
191     assertThat(
192         selection.getTextClassification(),
193         isTextClassification(suggested, TextClassifier.TYPE_URL));
194     assertThat(selection.getTextClassification(), containsIntentWithAction(Intent.ACTION_VIEW));
195   }
196 
197   @SdkSuppress(minSdkVersion = 31, codeName = "S")
198   @Test
testSuggestSelection_notIncludeTextClassification()199   public void testSuggestSelection_notIncludeTextClassification() throws IOException {
200     String text = "Visit http://www.android.com for more information";
201     TextSelection.Request request =
202         new TextSelection.Request.Builder(text, /* startIndex= */ 0, /* endIndex= */ 4)
203             .setIncludeTextClassification(false)
204             .build();
205 
206     TextSelection selection = classifier.suggestSelection(null, null, request);
207 
208     assertThat(selection.getTextClassification()).isNull();
209   }
210 
211   @Test
testClassifyText()212   public void testClassifyText() throws IOException {
213     String text = "Contact me at droid@android.com";
214     String classifiedText = "droid@android.com";
215     int startIndex = text.indexOf(classifiedText);
216     int endIndex = startIndex + classifiedText.length();
217     TextClassification.Request request =
218         new TextClassification.Request.Builder(text, startIndex, endIndex).build();
219 
220     TextClassification classification =
221         classifier.classifyText(/* sessionId= */ null, null, request);
222     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_EMAIL));
223   }
224 
225   @Test
testClassifyText_url()226   public void testClassifyText_url() throws IOException {
227     String text = "Visit www.android.com for more information";
228     String classifiedText = "www.android.com";
229     int startIndex = text.indexOf(classifiedText);
230     int endIndex = startIndex + classifiedText.length();
231     TextClassification.Request request =
232         new TextClassification.Request.Builder(text, startIndex, endIndex).build();
233 
234     TextClassification classification = classifier.classifyText(null, null, request);
235     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
236     assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
237   }
238 
239   @Test
testClassifyText_address()240   public void testClassifyText_address() throws IOException {
241     String text = "Brandschenkestrasse 110, Zürich, Switzerland";
242     TextClassification.Request request =
243         new TextClassification.Request.Builder(text, 0, text.length()).build();
244 
245     TextClassification classification = classifier.classifyText(null, null, request);
246     assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS));
247   }
248 
249   @Test
testClassifyText_url_inCaps()250   public void testClassifyText_url_inCaps() throws IOException {
251     String text = "Visit HTTP://ANDROID.COM for more information";
252     String classifiedText = "HTTP://ANDROID.COM";
253     int startIndex = text.indexOf(classifiedText);
254     int endIndex = startIndex + classifiedText.length();
255     TextClassification.Request request =
256         new TextClassification.Request.Builder(text, startIndex, endIndex).build();
257 
258     TextClassification classification = classifier.classifyText(null, null, request);
259     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
260     assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
261   }
262 
263   @Test
testClassifyText_date()264   public void testClassifyText_date() throws IOException {
265     String text = "Let's meet on January 9, 2018.";
266     String classifiedText = "January 9, 2018";
267     int startIndex = text.indexOf(classifiedText);
268     int endIndex = startIndex + classifiedText.length();
269     TextClassification.Request request =
270         new TextClassification.Request.Builder(text, startIndex, endIndex).build();
271 
272     TextClassification classification = classifier.classifyText(null, null, request);
273     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
274     Bundle extras = classification.getExtras();
275     List<Bundle> entities = ExtrasUtils.getEntities(extras);
276     assertThat(entities).hasSize(1);
277     assertThat(ExtrasUtils.getEntityType(entities.get(0))).isEqualTo(TextClassifier.TYPE_DATE);
278     ArrayList<Intent> actionsIntents = ExtrasUtils.getActionsIntents(classification);
279     actionsIntents.forEach(TextClassifierImplTest::assertNoPackageInfoInExtras);
280   }
281 
282   @Test
testClassifyText_datetime()283   public void testClassifyText_datetime() throws IOException {
284     String text = "Let's meet 2018/01/01 10:30:20.";
285     String classifiedText = "2018/01/01 10:30:20";
286     int startIndex = text.indexOf(classifiedText);
287     int endIndex = startIndex + classifiedText.length();
288     TextClassification.Request request =
289         new TextClassification.Request.Builder(text, startIndex, endIndex).build();
290 
291     TextClassification classification = classifier.classifyText(null, null, request);
292     assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME));
293   }
294 
295   @Test
testClassifyText_foreignText()296   public void testClassifyText_foreignText() throws IOException {
297     LocaleList originalLocales = LocaleList.getDefault();
298     LocaleList.setDefault(LocaleList.forLanguageTags("en"));
299     String japaneseText = "これは日本語のテキストです";
300     TextClassification.Request request =
301         new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length()).build();
302 
303     TextClassification classification = classifier.classifyText(null, null, request);
304     RemoteAction translateAction = classification.getActions().get(0);
305     assertEquals(1, classification.getActions().size());
306     assertEquals(Intent.ACTION_TRANSLATE, classification.getIntent().getAction());
307 
308     assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification));
309     Intent intent = ExtrasUtils.getActionsIntents(classification).get(0);
310     assertNoPackageInfoInExtras(intent);
311     assertEquals(Intent.ACTION_TRANSLATE, intent.getAction());
312     Bundle foreignLanguageInfo = ExtrasUtils.getForeignLanguageExtra(classification);
313     assertEquals("ja", ExtrasUtils.getEntityType(foreignLanguageInfo));
314     assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) >= 0);
315     assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) <= 1);
316     assertTrue(intent.hasExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER));
317     assertEquals("ja", ExtrasUtils.getTopLanguage(intent).first);
318 
319     LocaleList.setDefault(originalLocales);
320   }
321 
322   @Test
testGenerateLinks_phone()323   public void testGenerateLinks_phone() throws IOException {
324     String text = "The number is +12122537077. See you tonight!";
325     TextLinks.Request request = new TextLinks.Request.Builder(text).build();
326     assertThat(
327         classifier.generateLinks(null, null, request),
328         isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE));
329   }
330 
331   @Test
332   @RequiresFlagsEnabled(Flags.FLAG_ENABLE_OTP_IN_TEXT_CLASSIFIERS)
testGenerateLinks_otp()333   public void testGenerateLinks_otp() throws IOException {
334     Assume.assumeTrue(TextClassifierUtils.isOtpClassificationEnabled());
335     String text = "Your OTP code is 123456";
336     List<String> included = new ArrayList<>(List.of(TextClassifier.TYPE_OTP));
337     TextClassifier.EntityConfig config =
338             new TextClassifier.EntityConfig.Builder()
339                     .setIncludedTypes(included)
340                     .includeTypesFromTextClassifier(false)
341                     .build();
342     TextLinks.Request request = new TextLinks.Request.Builder(text).setEntityConfig(config).build();
343     assertThat(
344             classifier.generateLinks(null, null, request),
345             isTextLinksContaining(text, "", TextClassifier.TYPE_OTP));
346   }
347 
348   @Test
349   @RequiresFlagsEnabled(Flags.FLAG_ENABLE_OTP_IN_TEXT_CLASSIFIERS)
testGenerateLinks_otpTypeNotInRequest()350   public void testGenerateLinks_otpTypeNotInRequest() throws IOException {
351     Assume.assumeTrue(TextClassifierUtils.isOtpClassificationEnabled());
352     String text = "Your OTP code is 123456";
353     TextLinks.Request request = new TextLinks.Request.Builder(text).build();
354     assertThat(
355             classifier.generateLinks(null, null, request).getLinks()).isEmpty();
356   }
357 
358   @Test
testGenerateLinks_exclude()359   public void testGenerateLinks_exclude() throws IOException {
360     String text = "The number is +12122537077. See you tonight!";
361     List<String> hints = ImmutableList.of();
362     List<String> included = ImmutableList.of();
363     List<String> excluded = Arrays.asList(TextClassifier.TYPE_PHONE);
364     TextLinks.Request request =
365         new TextLinks.Request.Builder(text)
366             .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
367             .build();
368     assertThat(
369         classifier.generateLinks(null, null, request),
370         not(isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE)));
371   }
372 
373   @Test
testGenerateLinks_explicit_address()374   public void testGenerateLinks_explicit_address() throws IOException {
375     String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!";
376     List<String> explicit = Arrays.asList(TextClassifier.TYPE_ADDRESS);
377     TextLinks.Request request =
378         new TextLinks.Request.Builder(text)
379             .setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit))
380             .build();
381     assertThat(
382         classifier.generateLinks(null, null, request),
383         isTextLinksContaining(
384             text, "1600 Amphitheater Parkway, Mountain View, CA", TextClassifier.TYPE_ADDRESS));
385   }
386 
387   @Test
testGenerateLinks_exclude_override()388   public void testGenerateLinks_exclude_override() throws IOException {
389     String text = "You want apple@banana.com. See you tonight!";
390     List<String> hints = ImmutableList.of();
391     List<String> included = Arrays.asList(TextClassifier.TYPE_EMAIL);
392     List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL);
393     TextLinks.Request request =
394         new TextLinks.Request.Builder(text)
395             .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
396             .build();
397     assertThat(
398         classifier.generateLinks(null, null, request),
399         not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
400   }
401 
402   @Test
testGenerateLinks_maxLength()403   public void testGenerateLinks_maxLength() throws IOException {
404     char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength()];
405     Arrays.fill(manySpaces, ' ');
406     TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
407     TextLinks links = classifier.generateLinks(null, null, request);
408     assertTrue(links.getLinks().isEmpty());
409   }
410 
411   @Test
testApplyLinks_unsupportedCharacter()412   public void testApplyLinks_unsupportedCharacter() throws IOException {
413     Spannable url = new SpannableString("\u202Emoc.diordna.com");
414     TextLinks.Request request = new TextLinks.Request.Builder(url).build();
415     assertEquals(
416         TextLinks.STATUS_UNSUPPORTED_CHARACTER,
417         classifier.generateLinks(null, null, request).apply(url, 0, null));
418   }
419 
420   @Test
testGenerateLinks_tooLong()421   public void testGenerateLinks_tooLong() {
422     char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength() + 1];
423     Arrays.fill(manySpaces, ' ');
424     TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
425     expectThrows(
426         IllegalArgumentException.class, () -> classifier.generateLinks(null, null, request));
427   }
428 
429   @Test
testGenerateLinks_entityData()430   public void testGenerateLinks_entityData() throws IOException {
431     String text = "The number is +12122537077.";
432     Bundle extras = new Bundle();
433     ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true);
434     TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build();
435 
436     TextLinks textLinks = classifier.generateLinks(null, null, request);
437 
438     assertThat(textLinks.getLinks()).hasSize(1);
439     TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
440     List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
441     assertThat(entities).hasSize(1);
442     Bundle entity = entities.get(0);
443     assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_PHONE);
444   }
445 
446   @Test
testGenerateLinks_entityData_disabled()447   public void testGenerateLinks_entityData_disabled() throws IOException {
448     String text = "The number is +12122537077.";
449     TextLinks.Request request = new TextLinks.Request.Builder(text).build();
450 
451     TextLinks textLinks = classifier.generateLinks(null, null, request);
452 
453     assertThat(textLinks.getLinks()).hasSize(1);
454     TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
455     List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
456     assertThat(entities).isNull();
457   }
458 
459   @Test
testDetectLanguage()460   public void testDetectLanguage() throws IOException {
461     String text = "This is English text";
462     TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
463     TextLanguage textLanguage = classifier.detectLanguage(null, null, request);
464     assertThat(textLanguage, isTextLanguage("en"));
465   }
466 
467   @Test
testDetectLanguage_japanese()468   public void testDetectLanguage_japanese() throws IOException {
469     String text = "これは日本語のテキストです";
470     TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
471     TextLanguage textLanguage = classifier.detectLanguage(null, null, request);
472     assertThat(textLanguage, isTextLanguage("ja"));
473   }
474 
475   @Test
testSuggestConversationActions_textReplyOnly_maxOne()476   public void testSuggestConversationActions_textReplyOnly_maxOne() throws IOException {
477     ConversationActions.Message message =
478         new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
479             .setText("Where are you?")
480             .build();
481     TextClassifier.EntityConfig typeConfig =
482         new TextClassifier.EntityConfig.Builder()
483             .includeTypesFromTextClassifier(false)
484             .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY))
485             .build();
486     ConversationActions.Request request =
487         new ConversationActions.Request.Builder(Collections.singletonList(message))
488             .setMaxSuggestions(1)
489             .setTypeConfig(typeConfig)
490             .build();
491 
492     ConversationActions conversationActions =
493         classifier.suggestConversationActions(null, null, request);
494     assertThat(conversationActions.getConversationActions()).hasSize(1);
495     ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
496     assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_TEXT_REPLY);
497     assertThat(conversationAction.getTextReply()).isNotNull();
498   }
499 
500   @Test
testSuggestConversationActions_textReplyOnly_noMax()501   public void testSuggestConversationActions_textReplyOnly_noMax() throws IOException {
502     ConversationActions.Message message =
503         new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
504             .setText("Where are you?")
505             .build();
506     TextClassifier.EntityConfig typeConfig =
507         new TextClassifier.EntityConfig.Builder()
508             .includeTypesFromTextClassifier(false)
509             .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY))
510             .build();
511     ConversationActions.Request request =
512         new ConversationActions.Request.Builder(Collections.singletonList(message))
513             .setTypeConfig(typeConfig)
514             .build();
515 
516     ConversationActions conversationActions =
517         classifier.suggestConversationActions(null, null, request);
518     assertTrue(conversationActions.getConversationActions().size() > 1);
519     for (ConversationAction conversationAction : conversationActions.getConversationActions()) {
520       assertThat(conversationAction, isConversationAction(ConversationAction.TYPE_TEXT_REPLY));
521     }
522   }
523 
524   @Test
testSuggestConversationActions_openUrl()525   public void testSuggestConversationActions_openUrl() throws IOException {
526     ConversationActions.Message message =
527         new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
528             .setText("Check this out: https://www.android.com")
529             .build();
530     TextClassifier.EntityConfig typeConfig =
531         new TextClassifier.EntityConfig.Builder()
532             .includeTypesFromTextClassifier(false)
533             .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_OPEN_URL))
534             .build();
535     ConversationActions.Request request =
536         new ConversationActions.Request.Builder(Collections.singletonList(message))
537             .setMaxSuggestions(1)
538             .setTypeConfig(typeConfig)
539             .build();
540 
541     ConversationActions conversationActions =
542         classifier.suggestConversationActions(null, null, request);
543     assertThat(conversationActions.getConversationActions()).hasSize(1);
544     ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
545     assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
546     Intent actionIntent = ExtrasUtils.getActionIntent(conversationAction.getExtras());
547     assertThat(actionIntent.getAction()).isEqualTo(Intent.ACTION_VIEW);
548     assertThat(actionIntent.getData()).isEqualTo(Uri.parse("https://www.android.com"));
549     assertNoPackageInfoInExtras(actionIntent);
550   }
551 
552   @Test
testSuggestConversationActions_copy()553   public void testSuggestConversationActions_copy() throws IOException {
554     ConversationActions.Message message =
555         new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
556             .setText("Authentication code: 12345")
557             .build();
558     TextClassifier.EntityConfig typeConfig =
559         new TextClassifier.EntityConfig.Builder()
560             .includeTypesFromTextClassifier(false)
561             .setIncludedTypes(Collections.singletonList(TYPE_COPY))
562             .build();
563     ConversationActions.Request request =
564         new ConversationActions.Request.Builder(Collections.singletonList(message))
565             .setMaxSuggestions(1)
566             .setTypeConfig(typeConfig)
567             .build();
568 
569     ConversationActions conversationActions =
570         classifier.suggestConversationActions(null, null, request);
571     assertThat(conversationActions.getConversationActions()).hasSize(1);
572     ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
573     assertThat(conversationAction.getType()).isEqualTo(TYPE_COPY);
574     assertThat(conversationAction.getTextReply()).isAnyOf(null, "");
575     assertThat(conversationAction.getAction()).isNull();
576     String code = ExtrasUtils.getCopyText(conversationAction.getExtras());
577     assertThat(code).isEqualTo("12345");
578     assertThat(ExtrasUtils.getSerializedEntityData(conversationAction.getExtras())).isNotEmpty();
579   }
580 
581   @Test
testSuggestConversationActions_deduplicate()582   public void testSuggestConversationActions_deduplicate() throws IOException {
583     ConversationActions.Message message =
584         new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
585             .setText("a@android.com b@android.com")
586             .build();
587     ConversationActions.Request request =
588         new ConversationActions.Request.Builder(Collections.singletonList(message))
589             .setMaxSuggestions(3)
590             .build();
591 
592     ConversationActions conversationActions =
593         classifier.suggestConversationActions(null, null, request);
594 
595     assertThat(conversationActions.getConversationActions()).isEmpty();
596   }
597 
598   @Test
testUseCachedAnnotatorModelDisabled()599   public void testUseCachedAnnotatorModelDisabled() throws IOException {
600     deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
601 
602     String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath();
603     ModelFile annotatorModelA =
604         new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false);
605     ModelFile annotatorModelB =
606         new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false);
607 
608     String englishText = "You can reach me on +12122537077.";
609     String classifiedText = "+12122537077";
610     TextClassification.Request request =
611         new TextClassification.Request.Builder(englishText, 0, englishText.length()).build();
612 
613     // Check modelFileA v701
614     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
615         .thenReturn(annotatorModelA);
616     TextClassification classificationA = classifier.classifyText(null, null, request);
617 
618     assertThat(classificationA.getId()).contains("v701");
619     assertThat(classificationA.getText()).contains(classifiedText);
620     assertArrayEquals(
621         new int[] {0, 0, 0, 0},
622         new int[] {
623           annotatorModelCache.putCount(),
624           annotatorModelCache.evictionCount(),
625           annotatorModelCache.hitCount(),
626           annotatorModelCache.missCount()
627         });
628 
629     // Check modelFileB v801
630     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
631         .thenReturn(annotatorModelB);
632     TextClassification classificationB = classifier.classifyText(null, null, request);
633 
634     assertThat(classificationB.getId()).contains("v801");
635     assertThat(classificationB.getText()).contains(classifiedText);
636     assertArrayEquals(
637         new int[] {0, 0, 0, 0},
638         new int[] {
639           annotatorModelCache.putCount(),
640           annotatorModelCache.evictionCount(),
641           annotatorModelCache.hitCount(),
642           annotatorModelCache.missCount()
643         });
644 
645     // Reload modelFileA v701
646     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
647         .thenReturn(annotatorModelA);
648     TextClassification classificationAcached = classifier.classifyText(null, null, request);
649 
650     assertThat(classificationAcached.getId()).contains("v701");
651     assertThat(classificationAcached.getText()).contains(classifiedText);
652     assertArrayEquals(
653         new int[] {0, 0, 0, 0},
654         new int[] {
655           annotatorModelCache.putCount(),
656           annotatorModelCache.evictionCount(),
657           annotatorModelCache.hitCount(),
658           annotatorModelCache.missCount()
659         });
660   }
661 
662   @Test
testUseCachedAnnotatorModelEnabled()663   public void testUseCachedAnnotatorModelEnabled() throws IOException {
664     deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
665     deviceConfig.setConfig(TextClassifierSettings.MULTI_ANNOTATOR_CACHE_ENABLED, true);
666 
667     String annotatorFilePath = TestDataUtils.getTestAnnotatorModelFile().getPath();
668     ModelFile annotatorModelA =
669         new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 701, "en", false);
670     ModelFile annotatorModelB =
671         new ModelFile(ModelType.ANNOTATOR, annotatorFilePath, 801, "en", false);
672 
673     String englishText = "You can reach me on +12122537077.";
674     String classifiedText = "+12122537077";
675     TextClassification.Request request =
676         new TextClassification.Request.Builder(englishText, 0, englishText.length()).build();
677 
678     // Check modelFileA v701
679     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
680         .thenReturn(annotatorModelA);
681     TextClassification classification = classifier.classifyText(null, null, request);
682 
683     assertThat(classification.getId()).contains("v701");
684     assertThat(classification.getText()).contains(classifiedText);
685     assertArrayEquals(
686         new int[] {1, 0, 0, 1},
687         new int[] {
688           annotatorModelCache.putCount(),
689           annotatorModelCache.evictionCount(),
690           annotatorModelCache.hitCount(),
691           annotatorModelCache.missCount()
692         });
693 
694     // Check modelFileB v801
695     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
696         .thenReturn(annotatorModelB);
697     TextClassification classificationB = classifier.classifyText(null, null, request);
698 
699     assertThat(classificationB.getId()).contains("v801");
700     assertThat(classificationB.getText()).contains(classifiedText);
701     assertArrayEquals(
702         new int[] {2, 0, 0, 2},
703         new int[] {
704           annotatorModelCache.putCount(),
705           annotatorModelCache.evictionCount(),
706           annotatorModelCache.hitCount(),
707           annotatorModelCache.missCount()
708         });
709 
710     // Reload modelFileA v701
711     when(modelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
712         .thenReturn(annotatorModelA);
713     TextClassification classificationAcached = classifier.classifyText(null, null, request);
714 
715     assertThat(classificationAcached.getId()).contains("v701");
716     assertThat(classificationAcached.getText()).contains(classifiedText);
717     assertArrayEquals(
718         new int[] {2, 0, 1, 2},
719         new int[] {
720           annotatorModelCache.putCount(),
721           annotatorModelCache.evictionCount(),
722           annotatorModelCache.hitCount(),
723           annotatorModelCache.missCount()
724         });
725   }
726 
assertNoPackageInfoInExtras(Intent intent)727   private static void assertNoPackageInfoInExtras(Intent intent) {
728     assertThat(intent.getComponent()).isNull();
729     assertThat(intent.getPackage()).isNull();
730   }
731 
isTextSelection( final int startIndex, final int endIndex, final String type)732   private static Matcher<TextSelection> isTextSelection(
733       final int startIndex, final int endIndex, final String type) {
734     return new BaseMatcher<TextSelection>() {
735       @Override
736       public boolean matches(Object o) {
737         if (o instanceof TextSelection) {
738           TextSelection selection = (TextSelection) o;
739           return startIndex == selection.getSelectionStartIndex()
740               && endIndex == selection.getSelectionEndIndex()
741               && typeMatches(selection, type);
742         }
743         return false;
744       }
745 
746       private boolean typeMatches(TextSelection selection, String type) {
747         return type == null
748             || (selection.getEntityCount() > 0
749                 && type.trim().equalsIgnoreCase(selection.getEntity(0)));
750       }
751 
752       @Override
753       public void describeTo(Description description) {
754         description.appendValue(String.format("%d, %d, %s", startIndex, endIndex, type));
755       }
756     };
757   }
758 
759   private static Matcher<TextLinks> isTextLinksContaining(
760       final String text, final String substring, final String type) {
761     return new BaseMatcher<TextLinks>() {
762 
763       @Override
764       public void describeTo(Description description) {
765         description
766             .appendText("text=")
767             .appendValue(text)
768             .appendText(", substring=")
769             .appendValue(substring)
770             .appendText(", type=")
771             .appendValue(type);
772       }
773 
774       @Override
775       public boolean matches(Object o) {
776         if (o instanceof TextLinks) {
777           for (TextLinks.TextLink link : ((TextLinks) o).getLinks()) {
778             if (text.subSequence(link.getStart(), link.getEnd()).toString().equals(substring)) {
779               return type.equals(link.getEntity(0));
780             }
781           }
782         }
783         return false;
784       }
785     };
786   }
787 
788   private static Matcher<TextClassification> isTextClassification(
789       final String text, final String type) {
790     return new BaseMatcher<TextClassification>() {
791       @Override
792       public boolean matches(Object o) {
793         if (o instanceof TextClassification) {
794           TextClassification result = (TextClassification) o;
795           return text.equals(result.getText())
796               && result.getEntityCount() > 0
797               && type.equals(result.getEntity(0));
798         }
799         return false;
800       }
801 
802       @Override
803       public void describeTo(Description description) {
804         description.appendText("text=").appendValue(text).appendText(", type=").appendValue(type);
805       }
806     };
807   }
808 
809   private static Matcher<TextClassification> containsIntentWithAction(final String action) {
810     return new BaseMatcher<TextClassification>() {
811       @Override
812       public boolean matches(Object o) {
813         if (o instanceof TextClassification) {
814           TextClassification result = (TextClassification) o;
815           return ExtrasUtils.findAction(result, action) != null;
816         }
817         return false;
818       }
819 
820       @Override
821       public void describeTo(Description description) {
822         description.appendText("intent action=").appendValue(action);
823       }
824     };
825   }
826 
827   private static Matcher<TextLanguage> isTextLanguage(final String languageTag) {
828     return new BaseMatcher<TextLanguage>() {
829       @Override
830       public boolean matches(Object o) {
831         if (o instanceof TextLanguage) {
832           TextLanguage result = (TextLanguage) o;
833           return result.getLocaleHypothesisCount() > 0
834               && languageTag.equals(result.getLocale(0).toLanguageTag());
835         }
836         return false;
837       }
838 
839       @Override
840       public void describeTo(Description description) {
841         description.appendText("locale=").appendValue(languageTag);
842       }
843     };
844   }
845 
846   private static Matcher<ConversationAction> isConversationAction(String actionType) {
847     return new BaseMatcher<ConversationAction>() {
848       @Override
849       public boolean matches(Object o) {
850         if (!(o instanceof ConversationAction)) {
851           return false;
852         }
853         ConversationAction conversationAction = (ConversationAction) o;
854         if (!actionType.equals(conversationAction.getType())) {
855           return false;
856         }
857         if (ConversationAction.TYPE_TEXT_REPLY.equals(actionType)) {
858           if (conversationAction.getTextReply() == null) {
859             return false;
860           }
861         }
862         if (conversationAction.getConfidenceScore() < 0
863             || conversationAction.getConfidenceScore() > 1) {
864           return false;
865         }
866         return true;
867       }
868 
869       @Override
870       public void describeTo(Description description) {
871         description.appendText("actionType=").appendValue(actionType);
872       }
873     };
874   }
875 }
876