1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.textclassifier; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static org.mockito.Mockito.any; 21 import static org.mockito.Mockito.eq; 22 import static org.mockito.Mockito.verify; 23 import static org.mockito.Mockito.when; 24 25 import android.content.Context; 26 import android.os.CancellationSignal; 27 import android.service.textclassifier.TextClassifierService; 28 import android.view.textclassifier.ConversationAction; 29 import android.view.textclassifier.ConversationActions; 30 import android.view.textclassifier.TextClassification; 31 import android.view.textclassifier.TextClassifier; 32 import android.view.textclassifier.TextLanguage; 33 import android.view.textclassifier.TextLinks; 34 import android.view.textclassifier.TextLinks.TextLink; 35 import android.view.textclassifier.TextSelection; 36 import androidx.test.core.app.ApplicationProvider; 37 import androidx.test.ext.junit.runners.AndroidJUnit4; 38 import androidx.test.filters.SmallTest; 39 import com.android.internal.os.StatsdConfigProto.StatsdConfig; 40 import com.android.os.AtomsProto; 41 import com.android.os.AtomsProto.Atom; 42 import com.android.os.AtomsProto.TextClassifierApiUsageReported; 43 import com.android.os.AtomsProto.TextClassifierApiUsageReported.ApiType; 44 import com.android.os.AtomsProto.TextClassifierApiUsageReported.ResultType; 45 import com.android.textclassifier.common.ModelType; 46 import com.android.textclassifier.common.TextClassifierSettings; 47 import com.android.textclassifier.common.statsd.StatsdTestUtils; 48 import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger; 49 import com.android.textclassifier.downloader.ModelDownloadManager; 50 import com.google.common.base.Preconditions; 51 import com.google.common.collect.ImmutableList; 52 import com.google.common.util.concurrent.ListeningExecutorService; 53 import com.google.common.util.concurrent.MoreExecutors; 54 import java.io.IOException; 55 import java.util.List; 56 import java.util.concurrent.Executor; 57 import java.util.stream.Collectors; 58 import org.junit.After; 59 import org.junit.Before; 60 import org.junit.Rule; 61 import org.junit.Test; 62 import org.junit.runner.RunWith; 63 import org.mockito.ArgumentCaptor; 64 import org.mockito.Mock; 65 import org.mockito.Mockito; 66 import org.mockito.junit.MockitoJUnit; 67 import org.mockito.junit.MockitoRule; 68 69 @SmallTest 70 @RunWith(AndroidJUnit4.class) 71 public class DefaultTextClassifierServiceTest { 72 73 @Rule public final MockitoRule mocks = MockitoJUnit.rule(); 74 75 /** A statsd config ID, which is arbitrary. */ 76 private static final long CONFIG_ID = 689777; 77 78 private static final long SHORT_TIMEOUT_MS = 1000; 79 80 private static final String SESSION_ID = "abcdef"; 81 82 private TestInjector testInjector; 83 private DefaultTextClassifierService defaultTextClassifierService; 84 @Mock private TextClassifierService.Callback<TextClassification> textClassificationCallback; 85 @Mock private TextClassifierService.Callback<TextSelection> textSelectionCallback; 86 @Mock private TextClassifierService.Callback<TextLinks> textLinksCallback; 87 @Mock private TextClassifierService.Callback<ConversationActions> conversationActionsCallback; 88 @Mock private TextClassifierService.Callback<TextLanguage> textLanguageCallback; 89 @Mock private ModelFileManager testModelFileManager; 90 91 @Before setup()92 public void setup() throws IOException { 93 testInjector = 94 new TestInjector(ApplicationProvider.getApplicationContext(), testModelFileManager); 95 defaultTextClassifierService = new DefaultTextClassifierService(testInjector); 96 defaultTextClassifierService.onCreate(); 97 98 when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 99 .thenReturn(TestDataUtils.getTestAnnotatorModelFileWrapped()); 100 when(testModelFileManager.findBestModelFile(eq(ModelType.LANG_ID), any(), any())) 101 .thenReturn(TestDataUtils.getLangIdModelFileWrapped()); 102 when(testModelFileManager.findBestModelFile(eq(ModelType.ACTIONS_SUGGESTIONS), any(), any())) 103 .thenReturn(TestDataUtils.getTestActionsModelFileWrapped()); 104 } 105 106 @Before setupStatsdTestUtils()107 public void setupStatsdTestUtils() throws Exception { 108 StatsdTestUtils.cleanup(CONFIG_ID); 109 110 StatsdConfig.Builder builder = 111 StatsdConfig.newBuilder() 112 .setId(CONFIG_ID) 113 .addAllowedLogSource(ApplicationProvider.getApplicationContext().getPackageName()); 114 StatsdTestUtils.addAtomMatcher(builder, Atom.TEXT_CLASSIFIER_API_USAGE_REPORTED_FIELD_NUMBER); 115 StatsdTestUtils.pushConfig(builder.build()); 116 } 117 118 @After tearDown()119 public void tearDown() throws Exception { 120 StatsdTestUtils.cleanup(CONFIG_ID); 121 } 122 123 @Test classifyText_success()124 public void classifyText_success() throws Exception { 125 String text = "www.android.com"; 126 TextClassification.Request request = 127 new TextClassification.Request.Builder(text, 0, text.length()).build(); 128 129 defaultTextClassifierService.onClassifyText( 130 TestingUtils.createTextClassificationSessionId(SESSION_ID), 131 request, 132 new CancellationSignal(), 133 textClassificationCallback); 134 135 ArgumentCaptor<TextClassification> captor = ArgumentCaptor.forClass(TextClassification.class); 136 verify(textClassificationCallback).onSuccess(captor.capture()); 137 assertThat(captor.getValue().getEntityCount()).isGreaterThan(0); 138 assertThat(captor.getValue().getEntity(0)).isEqualTo(TextClassifier.TYPE_URL); 139 verifyApiUsageLog(ApiType.CLASSIFY_TEXT, ResultType.SUCCESS); 140 } 141 142 @Test suggestSelection_success()143 public void suggestSelection_success() throws Exception { 144 String text = "Visit http://www.android.com for more information"; 145 String selected = "http"; 146 String suggested = "http://www.android.com"; 147 int start = text.indexOf(selected); 148 int end = start + suggested.length(); 149 TextSelection.Request request = new TextSelection.Request.Builder(text, start, end).build(); 150 151 defaultTextClassifierService.onSuggestSelection( 152 TestingUtils.createTextClassificationSessionId(SESSION_ID), 153 request, 154 new CancellationSignal(), 155 textSelectionCallback); 156 157 ArgumentCaptor<TextSelection> captor = ArgumentCaptor.forClass(TextSelection.class); 158 verify(textSelectionCallback).onSuccess(captor.capture()); 159 assertThat(captor.getValue().getEntityCount()).isGreaterThan(0); 160 assertThat(captor.getValue().getEntity(0)).isEqualTo(TextClassifier.TYPE_URL); 161 verifyApiUsageLog(ApiType.SUGGEST_SELECTION, ResultType.SUCCESS); 162 } 163 164 @Test generateLinks_success()165 public void generateLinks_success() throws Exception { 166 String text = "Visit http://www.android.com for more information"; 167 TextLinks.Request request = new TextLinks.Request.Builder(text).build(); 168 169 defaultTextClassifierService.onGenerateLinks( 170 TestingUtils.createTextClassificationSessionId(SESSION_ID), 171 request, 172 new CancellationSignal(), 173 textLinksCallback); 174 175 ArgumentCaptor<TextLinks> captor = ArgumentCaptor.forClass(TextLinks.class); 176 verify(textLinksCallback).onSuccess(captor.capture()); 177 assertThat(captor.getValue().getLinks()).hasSize(1); 178 TextLink textLink = captor.getValue().getLinks().iterator().next(); 179 assertThat(textLink.getEntityCount()).isGreaterThan(0); 180 assertThat(textLink.getEntity(0)).isEqualTo(TextClassifier.TYPE_URL); 181 verifyApiUsageLog(ApiType.GENERATE_LINKS, ResultType.SUCCESS); 182 } 183 184 @Test detectLanguage_success()185 public void detectLanguage_success() throws Exception { 186 String text = "ピカチュウ"; 187 TextLanguage.Request request = new TextLanguage.Request.Builder(text).build(); 188 189 defaultTextClassifierService.onDetectLanguage( 190 TestingUtils.createTextClassificationSessionId(SESSION_ID), 191 request, 192 new CancellationSignal(), 193 textLanguageCallback); 194 195 ArgumentCaptor<TextLanguage> captor = ArgumentCaptor.forClass(TextLanguage.class); 196 verify(textLanguageCallback).onSuccess(captor.capture()); 197 assertThat(captor.getValue().getLocaleHypothesisCount()).isGreaterThan(0); 198 assertThat(captor.getValue().getLocale(0).toLanguageTag()).isEqualTo("ja"); 199 verifyApiUsageLog(ApiType.DETECT_LANGUAGES, ResultType.SUCCESS); 200 } 201 202 @Test suggestConversationActions_success()203 public void suggestConversationActions_success() throws Exception { 204 ConversationActions.Message message = 205 new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS) 206 .setText("Checkout www.android.com") 207 .build(); 208 ConversationActions.Request request = 209 new ConversationActions.Request.Builder(ImmutableList.of(message)).build(); 210 211 defaultTextClassifierService.onSuggestConversationActions( 212 TestingUtils.createTextClassificationSessionId(SESSION_ID), 213 request, 214 new CancellationSignal(), 215 conversationActionsCallback); 216 217 ArgumentCaptor<ConversationActions> captor = ArgumentCaptor.forClass(ConversationActions.class); 218 verify(conversationActionsCallback).onSuccess(captor.capture()); 219 List<ConversationAction> conversationActions = captor.getValue().getConversationActions(); 220 assertThat(conversationActions.size()).isGreaterThan(0); 221 assertThat(conversationActions.get(0).getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL); 222 verifyApiUsageLog(ApiType.SUGGEST_CONVERSATION_ACTIONS, ResultType.SUCCESS); 223 } 224 225 @Test missingModelFile_onFailureShouldBeCalled()226 public void missingModelFile_onFailureShouldBeCalled() throws Exception { 227 when(testModelFileManager.findBestModelFile(eq(ModelType.ANNOTATOR), any(), any())) 228 .thenReturn(null); 229 defaultTextClassifierService.onCreate(); 230 231 TextClassification.Request request = new TextClassification.Request.Builder("hi", 0, 2).build(); 232 defaultTextClassifierService.onClassifyText( 233 TestingUtils.createTextClassificationSessionId(SESSION_ID), 234 request, 235 new CancellationSignal(), 236 textClassificationCallback); 237 238 verify(textClassificationCallback).onFailure(Mockito.anyString()); 239 verifyApiUsageLog(ApiType.CLASSIFY_TEXT, ResultType.FAIL); 240 } 241 verifyApiUsageLog( AtomsProto.TextClassifierApiUsageReported.ApiType expectedApiType, AtomsProto.TextClassifierApiUsageReported.ResultType expectedResultApiType)242 private static void verifyApiUsageLog( 243 AtomsProto.TextClassifierApiUsageReported.ApiType expectedApiType, 244 AtomsProto.TextClassifierApiUsageReported.ResultType expectedResultApiType) 245 throws Exception { 246 ImmutableList<Atom> loggedAtoms = StatsdTestUtils.getLoggedAtoms(CONFIG_ID, SHORT_TIMEOUT_MS); 247 ImmutableList<TextClassifierApiUsageReported> loggedEvents = 248 ImmutableList.copyOf( 249 loggedAtoms.stream() 250 .map(Atom::getTextClassifierApiUsageReported) 251 .collect(Collectors.toList())); 252 assertThat(loggedEvents).hasSize(1); 253 TextClassifierApiUsageReported loggedEvent = loggedEvents.get(0); 254 assertThat(loggedEvent.getLatencyMillis()).isGreaterThan(0L); 255 assertThat(loggedEvent.getApiType()).isEqualTo(expectedApiType); 256 assertThat(loggedEvent.getResultType()).isEqualTo(expectedResultApiType); 257 assertThat(loggedEvent.getSessionId()).isEqualTo(SESSION_ID); 258 } 259 260 private static final class TestInjector implements DefaultTextClassifierService.Injector { 261 private final Context context; 262 private ModelFileManager modelFileManager; 263 TestInjector(Context context, ModelFileManager modelFileManager)264 private TestInjector(Context context, ModelFileManager modelFileManager) { 265 this.context = Preconditions.checkNotNull(context); 266 this.modelFileManager = Preconditions.checkNotNull(modelFileManager); 267 } 268 269 @Override getContext()270 public Context getContext() { 271 return context; 272 } 273 274 @Override createModelFileManager( TextClassifierSettings settings, ModelDownloadManager modelDownloadManager)275 public ModelFileManager createModelFileManager( 276 TextClassifierSettings settings, ModelDownloadManager modelDownloadManager) { 277 return modelFileManager; 278 } 279 280 @Override createTextClassifierSettings()281 public TextClassifierSettings createTextClassifierSettings() { 282 return new TextClassifierSettings(); 283 } 284 285 @Override createTextClassifierImpl( TextClassifierSettings settings, ModelFileManager modelFileManager)286 public TextClassifierImpl createTextClassifierImpl( 287 TextClassifierSettings settings, ModelFileManager modelFileManager) { 288 return new TextClassifierImpl(context, settings, modelFileManager); 289 } 290 291 @Override createNormPriorityExecutor()292 public ListeningExecutorService createNormPriorityExecutor() { 293 return MoreExecutors.newDirectExecutorService(); 294 } 295 296 @Override createLowPriorityExecutor()297 public ListeningExecutorService createLowPriorityExecutor() { 298 return MoreExecutors.newDirectExecutorService(); 299 } 300 301 @Override createTextClassifierApiUsageLogger( TextClassifierSettings settings, Executor executor)302 public TextClassifierApiUsageLogger createTextClassifierApiUsageLogger( 303 TextClassifierSettings settings, Executor executor) { 304 return new TextClassifierApiUsageLogger( 305 /* sampleRateSupplier= */ () -> 1, MoreExecutors.directExecutor()); 306 } 307 } 308 } 309