• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.textclassifier;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static org.mockito.Mockito.any;
21 import static org.mockito.Mockito.eq;
22 import static org.mockito.Mockito.verify;
23 import static org.mockito.Mockito.when;
24 
25 import android.content.Context;
26 import android.os.CancellationSignal;
27 import android.service.textclassifier.TextClassifierService;
28 import android.view.textclassifier.ConversationAction;
29 import android.view.textclassifier.ConversationActions;
30 import android.view.textclassifier.TextClassification;
31 import android.view.textclassifier.TextClassifier;
32 import android.view.textclassifier.TextLanguage;
33 import android.view.textclassifier.TextLinks;
34 import android.view.textclassifier.TextLinks.TextLink;
35 import android.view.textclassifier.TextSelection;
36 import androidx.test.core.app.ApplicationProvider;
37 import androidx.test.ext.junit.runners.AndroidJUnit4;
38 import androidx.test.filters.SmallTest;
39 import com.android.internal.os.StatsdConfigProto.StatsdConfig;
40 import com.android.os.AtomsProto;
41 import com.android.os.AtomsProto.Atom;
42 import com.android.os.AtomsProto.TextClassifierApiUsageReported;
43 import com.android.os.AtomsProto.TextClassifierApiUsageReported.ApiType;
44 import com.android.os.AtomsProto.TextClassifierApiUsageReported.ResultType;
45 import com.android.textclassifier.common.ModelType;
46 import com.android.textclassifier.common.TextClassifierSettings;
47 import com.android.textclassifier.common.statsd.StatsdTestUtils;
48 import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger;
49 import com.android.textclassifier.downloader.ModelDownloadManager;
50 import com.google.common.base.Preconditions;
51 import com.google.common.collect.ImmutableList;
52 import com.google.common.util.concurrent.ListeningExecutorService;
53 import com.google.common.util.concurrent.MoreExecutors;
54 import java.io.IOException;
55 import java.util.List;
56 import java.util.concurrent.Executor;
57 import java.util.stream.Collectors;
58 import org.junit.After;
59 import org.junit.Before;
60 import org.junit.Rule;
61 import org.junit.Test;
62 import org.junit.runner.RunWith;
63 import org.mockito.ArgumentCaptor;
64 import org.mockito.Mock;
65 import org.mockito.Mockito;
66 import org.mockito.junit.MockitoJUnit;
67 import org.mockito.junit.MockitoRule;
68 
69 @SmallTest
70 @RunWith(AndroidJUnit4.class)
71 public class DefaultTextClassifierServiceTest {
72 
73   @Rule public final MockitoRule mocks = MockitoJUnit.rule();
74 
75   /** A statsd config ID, which is arbitrary. */
76   private static final long CONFIG_ID = 689777;
77 
78   private static final long SHORT_TIMEOUT_MS = 1000;
79 
80   private static final String SESSION_ID = "abcdef";
81 
82   private TestInjector testInjector;
83   private DefaultTextClassifierService defaultTextClassifierService;
84   @Mock private TextClassifierService.Callback<TextClassification> textClassificationCallback;
85   @Mock private TextClassifierService.Callback<TextSelection> textSelectionCallback;
86   @Mock private TextClassifierService.Callback<TextLinks> textLinksCallback;
87   @Mock private TextClassifierService.Callback<ConversationActions> conversationActionsCallback;
88   @Mock private TextClassifierService.Callback<TextLanguage> textLanguageCallback;
89   @Mock private ModelFileManager testModelFileManager;
90 
91   @Before
setup()92   public void setup() throws IOException {
93     testInjector =
94         new TestInjector(ApplicationProvider.getApplicationContext(), testModelFileManager);
95     defaultTextClassifierService = new DefaultTextClassifierService(testInjector);
96     defaultTextClassifierService.onCreate();
97 
98     when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
99         .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped());
100     when(testModelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any()))
101         .thenReturn(TestDataUtils.getLangIdModelFileWrapped());
102     when(testModelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any()))
103         .thenReturn(TestDataUtils.getTestActionsModelFileWrapped());
104   }
105 
106   @Before
setupStatsdTestUtils()107   public void setupStatsdTestUtils() throws Exception {
108     StatsdTestUtils.cleanup(CONFIG_ID);
109 
110     StatsdConfig.Builder builder =
111         StatsdConfig.newBuilder()
112             .setId(CONFIG_ID)
113             .addAllowedLogSource(ApplicationProvider.getApplicationContext().getPackageName());
114     StatsdTestUtils.addAtomMatcher(builder, Atom.TEXT_CLASSIFIER_API_USAGE_REPORTED_FIELD_NUMBER);
115     StatsdTestUtils.pushConfig(builder.build());
116   }
117 
118   @After
tearDown()119   public void tearDown() throws Exception {
120     StatsdTestUtils.cleanup(CONFIG_ID);
121   }
122 
123   @Test
classifyText_success()124   public void classifyText_success() throws Exception {
125     String text = "www.android.com";
126     TextClassification.Request request =
127         new TextClassification.Request.Builder(text, 0, text.length()).build();
128 
129     defaultTextClassifierService.onClassifyText(
130         TestingUtils.createTextClassificationSessionId(SESSION_ID),
131         request,
132         new CancellationSignal(),
133         textClassificationCallback);
134 
135     ArgumentCaptor<TextClassification> captor = ArgumentCaptor.forClass(TextClassification.class);
136     verify(textClassificationCallback).onSuccess(captor.capture());
137     assertThat(captor.getValue().getEntityCount()).isGreaterThan(0);
138     assertThat(captor.getValue().getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
139     verifyApiUsageLog(ApiType.CLASSIFY_TEXT, ResultType.SUCCESS);
140   }
141 
142   @Test
suggestSelection_success()143   public void suggestSelection_success() throws Exception {
144     String text = "Visit http://www.android.com for more information";
145     String selected = "http";
146     String suggested = "http://www.android.com";
147     int start = text.indexOf(selected);
148     int end = start + suggested.length();
149     TextSelection.Request request = new TextSelection.Request.Builder(text, start, end).build();
150 
151     defaultTextClassifierService.onSuggestSelection(
152         TestingUtils.createTextClassificationSessionId(SESSION_ID),
153         request,
154         new CancellationSignal(),
155         textSelectionCallback);
156 
157     ArgumentCaptor<TextSelection> captor = ArgumentCaptor.forClass(TextSelection.class);
158     verify(textSelectionCallback).onSuccess(captor.capture());
159     assertThat(captor.getValue().getEntityCount()).isGreaterThan(0);
160     assertThat(captor.getValue().getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
161     verifyApiUsageLog(ApiType.SUGGEST_SELECTION, ResultType.SUCCESS);
162   }
163 
164   @Test
generateLinks_success()165   public void generateLinks_success() throws Exception {
166     String text = "Visit http://www.android.com for more information";
167     TextLinks.Request request = new TextLinks.Request.Builder(text).build();
168 
169     defaultTextClassifierService.onGenerateLinks(
170         TestingUtils.createTextClassificationSessionId(SESSION_ID),
171         request,
172         new CancellationSignal(),
173         textLinksCallback);
174 
175     ArgumentCaptor<TextLinks> captor = ArgumentCaptor.forClass(TextLinks.class);
176     verify(textLinksCallback).onSuccess(captor.capture());
177     assertThat(captor.getValue().getLinks()).hasSize(1);
178     TextLink textLink = captor.getValue().getLinks().iterator().next();
179     assertThat(textLink.getEntityCount()).isGreaterThan(0);
180     assertThat(textLink.getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
181     verifyApiUsageLog(ApiType.GENERATE_LINKS, ResultType.SUCCESS);
182   }
183 
184   @Test
detectLanguage_success()185   public void detectLanguage_success() throws Exception {
186     String text = "ピカチュウ";
187     TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
188 
189     defaultTextClassifierService.onDetectLanguage(
190         TestingUtils.createTextClassificationSessionId(SESSION_ID),
191         request,
192         new CancellationSignal(),
193         textLanguageCallback);
194 
195     ArgumentCaptor<TextLanguage> captor = ArgumentCaptor.forClass(TextLanguage.class);
196     verify(textLanguageCallback).onSuccess(captor.capture());
197     assertThat(captor.getValue().getLocaleHypothesisCount()).isGreaterThan(0);
198     assertThat(captor.getValue().getLocale(0).toLanguageTag()).isEqualTo("ja");
199     verifyApiUsageLog(ApiType.DETECT_LANGUAGES, ResultType.SUCCESS);
200   }
201 
202   @Test
suggestConversationActions_success()203   public void suggestConversationActions_success() throws Exception {
204     ConversationActions.Message message =
205         new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
206             .setText("Checkout www.android.com")
207             .build();
208     ConversationActions.Request request =
209         new ConversationActions.Request.Builder(ImmutableList.of(message)).build();
210 
211     defaultTextClassifierService.onSuggestConversationActions(
212         TestingUtils.createTextClassificationSessionId(SESSION_ID),
213         request,
214         new CancellationSignal(),
215         conversationActionsCallback);
216 
217     ArgumentCaptor<ConversationActions> captor = ArgumentCaptor.forClass(ConversationActions.class);
218     verify(conversationActionsCallback).onSuccess(captor.capture());
219     List<ConversationAction> conversationActions = captor.getValue().getConversationActions();
220     assertThat(conversationActions.size()).isGreaterThan(0);
221     assertThat(conversationActions.get(0).getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
222     verifyApiUsageLog(ApiType.SUGGEST_CONVERSATION_ACTIONS, ResultType.SUCCESS);
223   }
224 
225   @Test
missingModelFile_onFailureShouldBeCalled()226   public void missingModelFile_onFailureShouldBeCalled() throws Exception {
227     when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any()))
228         .thenReturn(null);
229     defaultTextClassifierService.onCreate();
230 
231     TextClassification.Request request = new TextClassification.Request.Builder("hi", 0, 2).build();
232     defaultTextClassifierService.onClassifyText(
233         TestingUtils.createTextClassificationSessionId(SESSION_ID),
234         request,
235         new CancellationSignal(),
236         textClassificationCallback);
237 
238     verify(textClassificationCallback).onFailure(Mockito.anyString());
239     verifyApiUsageLog(ApiType.CLASSIFY_TEXT, ResultType.FAIL);
240   }
241 
verifyApiUsageLog( AtomsProto.TextClassifierApiUsageReported.ApiType expectedApiType, AtomsProto.TextClassifierApiUsageReported.ResultType expectedResultApiType)242   private static void verifyApiUsageLog(
243       AtomsProto.TextClassifierApiUsageReported.ApiType expectedApiType,
244       AtomsProto.TextClassifierApiUsageReported.ResultType expectedResultApiType)
245       throws Exception {
246     ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS);
247     ImmutableList<TextClassifierApiUsageReported> loggedEvents =
248         ImmutableList.copyOf(
249             loggedAtoms.stream()
250                 .map(Atom::getTextClassifierApiUsageReported)
251                 .collect(Collectors.toList()));
252     assertThat(loggedEvents).hasSize(1);
253     TextClassifierApiUsageReported loggedEvent = loggedEvents.get(0);
254     assertThat(loggedEvent.getLatencyMillis()).isGreaterThan(0L);
255     assertThat(loggedEvent.getApiType()).isEqualTo(expectedApiType);
256     assertThat(loggedEvent.getResultType()).isEqualTo(expectedResultApiType);
257     assertThat(loggedEvent.getSessionId()).isEqualTo(SESSION_ID);
258   }
259 
260   private static final class TestInjector implements DefaultTextClassifierService.Injector {
261     private final Context context;
262     private ModelFileManager modelFileManager;
263 
TestInjector(Context context, ModelFileManager modelFileManager)264     private TestInjector(Context context, ModelFileManager modelFileManager) {
265       this.context = Preconditions.checkNotNull(context);
266       this.modelFileManager = Preconditions.checkNotNull(modelFileManager);
267     }
268 
269     @Override
getContext()270     public Context getContext() {
271       return context;
272     }
273 
274     @Override
createModelFileManager( TextClassifierSettings settings, ModelDownloadManager modelDownloadManager)275     public ModelFileManager createModelFileManager(
276         TextClassifierSettings settings, ModelDownloadManager modelDownloadManager) {
277       return modelFileManager;
278     }
279 
280     @Override
createTextClassifierSettings()281     public TextClassifierSettings createTextClassifierSettings() {
282       return new TextClassifierSettings();
283     }
284 
285     @Override
createTextClassifierImpl( TextClassifierSettings settings, ModelFileManager modelFileManager)286     public TextClassifierImpl createTextClassifierImpl(
287         TextClassifierSettings settings, ModelFileManager modelFileManager) {
288       return new TextClassifierImpl(context, settings, modelFileManager);
289     }
290 
291     @Override
createNormPriorityExecutor()292     public ListeningExecutorService createNormPriorityExecutor() {
293       return MoreExecutors.newDirectExecutorService();
294     }
295 
296     @Override
createLowPriorityExecutor()297     public ListeningExecutorService createLowPriorityExecutor() {
298       return MoreExecutors.newDirectExecutorService();
299     }
300 
301     @Override
createTextClassifierApiUsageLogger( TextClassifierSettings settings, Executor executor)302     public TextClassifierApiUsageLogger createTextClassifierApiUsageLogger(
303         TextClassifierSettings settings, Executor executor) {
304       return new TextClassifierApiUsageLogger(
305           /* sampleRateSupplier= */ () -> 1, MoreExecutors.directExecutor());
306     }
307   }
308 }
309