/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.textclassifier;

import static com.android.textclassifier.common.ModelFile.LANGUAGE_INDEPENDENT;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.Mockito.when;

import android.content.Context;
import android.os.LocaleList;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
import androidx.work.WorkManager;
import com.android.textclassifier.ModelFileManagerImpl.DownloaderModelsLister;
import com.android.textclassifier.ModelFileManagerImpl.RegularFileFullMatchLister;
import com.android.textclassifier.ModelFileManagerImpl.RegularFilePatternMatchLister;
import com.android.textclassifier.common.ModelFile;
import com.android.textclassifier.common.ModelType;
import com.android.textclassifier.common.ModelType.ModelTypeDef;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.downloader.DownloadedModelManager;
import com.android.textclassifier.downloader.ModelDownloadManager;
import com.android.textclassifier.downloader.ModelDownloadWorker;
import com.android.textclassifier.testing.SetDefaultLocalesRule;
import com.android.textclassifier.testing.TestingDeviceConfig;
import com.google.common.collect.ImmutableList;
import com.google.common.io.Files;
import com.google.common.util.concurrent.MoreExecutors;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;

/**
 * Tests for {@link ModelFileManagerImpl} and its nested model listers ({@link
 * DownloaderModelsLister}, {@link RegularFileFullMatchLister} and {@link
 * RegularFilePatternMatchLister}).
 *
 * <p>Each test works inside a scratch directory under the application cache dir that is created
 * in {@link #setup()} and removed in {@link #removeTestDir()}.
 */
@SmallTest
@RunWith(AndroidJUnit4.class)
public final class ModelFileManagerImplTest {
  private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");

  @ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;

  private TestingDeviceConfig deviceConfig;

  @Mock private DownloadedModelManager downloadedModelManager;

  @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
  @Rule public final MockitoRule mocks = MockitoJUnit.rule();

  private File rootTestDir;
  private ModelFileManagerImpl modelFileManager;
  private ModelDownloadManager modelDownloadManager;
  private TextClassifierSettings settings;

  /** Builds the objects under test and resets the default locale list to {@code en-US}. */
  @Before
  public void setup() {
    // Resolve the context once and reuse it for both the scratch directory and the managers.
    Context context = ApplicationProvider.getApplicationContext();
    deviceConfig = new TestingDeviceConfig();
    rootTestDir = new File(context.getCacheDir(), "rootTestDir");
    rootTestDir.mkdirs();
    settings = new TextClassifierSettings(deviceConfig, /* isWear= */ false);
    modelDownloadManager =
        new ModelDownloadManager(
            context,
            ModelDownloadWorker.class,
            () -> WorkManager.getInstance(context),
            downloadedModelManager,
            settings,
            MoreExecutors.newDirectExecutorService());
    modelFileManager = new ModelFileManagerImpl(context, modelDownloadManager, settings);
    setDefaultLocalesRule.set(new LocaleList(DEFAULT_LOCALE));
  }

  /** Deletes the scratch directory and everything the test wrote into it. */
  @After
  public void removeTestDir() {
    recursiveDelete(rootTestDir);
  }

  @Test
  public void annotatorModelPreloaded() {
    verifyModelPreloadedAsAsset(ModelType.ANNOTATOR, "textclassifier/annotator.universal.model");
  }

  @Test
  public void actionsModelPreloaded() {
    verifyModelPreloadedAsAsset(
        ModelType.ACTIONS_SUGGESTIONS, "textclassifier/actions_suggestions.universal.model");
  }

  @Test
  public void langIdModelPreloaded() {
    verifyModelPreloadedAsAsset(ModelType.LANG_ID, "textclassifier/lang_id.model");
  }

  /**
   * Asserts that exactly one asset-backed model is listed for {@code modelType} and that its path
   * equals {@code expectedModelPath}.
   */
  private void verifyModelPreloadedAsAsset(
      @ModelTypeDef String modelType, String expectedModelPath) {
    List<ModelFile> modelFiles = modelFileManager.listModelFiles(modelType);
    List<ModelFile> assetFiles =
        modelFiles.stream().filter(modelFile -> modelFile.isAsset).collect(Collectors.toList());

    assertThat(assetFiles).hasSize(1);
    assertThat(assetFiles.get(0).absolutePath).isEqualTo(expectedModelPath);
  }

  /** Between two language-independent models, the higher version is expected to win. */
  @Test
  public void findBestModel_versionCode() {
    ModelFile olderModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
    ModelFile newerModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 2);
    ModelFileManager modelFileManager = createModelFileManager(olderModelFile, newerModelFile);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, /* localePreferences= */ null, /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(newerModelFile);
  }

  /** A model matching the requested locale is expected over a language-independent one. */
  @Test
  public void findBestModel_languageDependentModelIsPreferred() {
    ModelFile languageIndependentModelFile =
        createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
    ModelFile languageDependentModelFile =
        createModelFile(DEFAULT_LOCALE.toLanguageTag(), /* version= */ 1);
    ModelFileManager modelFileManager =
        createModelFileManager(languageIndependentModelFile, languageDependentModelFile);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, new LocaleList(DEFAULT_LOCALE), /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
  }

  /** When no model matches the requested locale, the language-independent one is expected. */
  @Test
  public void findBestModel_noMatchedLanguageModel() {
    ModelFile languageIndependentModelFile =
        createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
    ModelFile languageDependentModelFile = createModelFile("zh-hk", /* version= */ 1);
    ModelFileManager modelFileManager =
        createModelFileManager(languageIndependentModelFile, languageDependentModelFile);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, new LocaleList(DEFAULT_LOCALE), /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
  }

  /** An older model in the requested language is expected over a newer one in another language. */
  @Test
  public void findBestModel_languageIsMoreImportantThanVersion() {
    ModelFile matchButOlderModel =
        createModelFile(DEFAULT_LOCALE.toLanguageTag(), /* version= */ 1);
    ModelFile mismatchButNewerModel = createModelFile("zh-hk", /* version= */ 2);
    ModelFileManager modelFileManager =
        createModelFileManager(matchButOlderModel, mismatchButNewerModel);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, new LocaleList(DEFAULT_LOCALE), /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(matchButOlderModel);
  }

  /** Default locales: {@code zh}; requested: {@code zh-hk}; the {@code zh} model is expected. */
  @Test
  public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_onlyCheckLanguage() {
    setDefaultLocalesRule.set(LocaleList.forLanguageTags("zh"));
    ModelFile languageIndependentModelFile =
        createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
    ModelFile languageDependentModelFile = createModelFile("zh", /* version= */ 1);
    ModelFileManager modelFileManager =
        createModelFileManager(languageIndependentModelFile, languageDependentModelFile);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"), /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
  }

  /** Default locales: {@code zh-hk}; requested: {@code zh}; the {@code zh} model is expected. */
  @Test
  public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_match() {
    setDefaultLocalesRule.set(LocaleList.forLanguageTags("zh-hk"));
    ModelFile languageIndependentModelFile =
        createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
    ModelFile languageDependentModelFile = createModelFile("zh", /* version= */ 1);
    ModelFileManager modelFileManager =
        createModelFileManager(languageIndependentModelFile, languageDependentModelFile);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, LocaleList.forLanguageTags("zh"), /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
  }

  /**
   * Default locales: {@code en}; requested: {@code zh}; the language-independent model is expected
   * even though a {@code zh} model is available.
   */
  @Test
  public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_doNotMatch() {
    setDefaultLocalesRule.set(LocaleList.forLanguageTags("en"));
    ModelFile languageIndependentModelFile =
        createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
    ModelFile languageDependentModelFile = createModelFile("zh", /* version= */ 1);
    ModelFileManager modelFileManager =
        createModelFileManager(languageIndependentModelFile, languageDependentModelFile);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, LocaleList.forLanguageTags("zh"), /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
  }

  /**
   * Default locales: {@code en, zh-hk}; no request preferences. The {@code zh-hk} model (second
   * default locale) is not expected to be chosen.
   */
  @Test
  public void findBestModel_onlyPrimaryLocaleConsidered_noLocalePreferencesProvided() {
    setDefaultLocalesRule.set(
        new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")));
    ModelFile languageIndependentModelFile =
        createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
    ModelFile nonPrimaryLocaleModelFile = createModelFile("zh-hk", /* version= */ 1);
    ModelFileManager modelFileManager =
        createModelFileManager(languageIndependentModelFile, nonPrimaryLocaleModelFile);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, /* localePreferences= */ null, /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
  }

  /**
   * Default and requested locales are both {@code en, zh-hk}. The {@code zh-hk} model (second
   * requested locale) is not expected to be chosen.
   */
  @Test
  public void findBestModel_onlyPrimaryLocaleConsidered_localePreferencesProvided() {
    setDefaultLocalesRule.set(
        new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")));

    ModelFile languageIndependentModelFile =
        createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
    ModelFile nonPrimaryLocalePreferenceModelFile = createModelFile("zh-hk", /* version= */ 1);
    ModelFileManager modelFileManager =
        createModelFileManager(languageIndependentModelFile, nonPrimaryLocalePreferenceModelFile);

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE,
            new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")),
            /* detectedLocales= */ null);

    assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
  }

  /**
   * Multi-language support on; neither the requested ({@code ja, fy}) nor the detected ({@code
   * hr}) locales have a model. The {@code en} model (the default locale) is expected.
   */
  @Test
  public void findBestModel_multiLanguageEnabled_noMatchedModel() {
    setDefaultLocalesRule.set(LocaleList.forLanguageTags("en"));
    deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, true);

    ModelFile primaryLocalePreferenceModelFile = createModelFile("en", /* version= */ 1);
    ModelFile secondaryLocalePreferenceModelFile = createModelFile("zh-hk", /* version= */ 1);
    ModelFileManager modelFileManager =
        createModelFileManager(
            primaryLocalePreferenceModelFile, secondaryLocalePreferenceModelFile);
    final LocaleList requestLocalePreferences =
        new LocaleList(Locale.forLanguageTag("ja"), Locale.forLanguageTag("fy"));
    final LocaleList detectedLocalePreferences = LocaleList.forLanguageTags("hr");

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, requestLocalePreferences, detectedLocalePreferences);

    assertThat(bestModelFile).isEqualTo(primaryLocalePreferenceModelFile);
  }

  /** Multi-language support on; detected locale {@code zh} has a model, which is expected. */
  @Test
  public void findBestModel_multiLanguageEnabled_matchDetected() {
    setDefaultLocalesRule.set(
        new LocaleList(Locale.forLanguageTag("en-GB"), Locale.forLanguageTag("zh-hk")));
    deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, true);

    ModelFile localePreferenceModelFile = createModelFile("zh", /* version= */ 1);
    ModelFileManager modelFileManager = createModelFileManager(localePreferenceModelFile);
    final LocaleList requestLocalePreferences =
        new LocaleList(Locale.forLanguageTag("ja"), Locale.forLanguageTag("zh"));
    final LocaleList detectedLocalePreferences = LocaleList.forLanguageTags("zh");

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, requestLocalePreferences, detectedLocalePreferences);

    assertThat(bestModelFile).isEqualTo(localePreferenceModelFile);
  }

  /**
   * Multi-language support off; only a {@code zh} model exists while {@code en} is requested and
   * nothing is detected. No model is expected.
   */
  @Test
  public void findBestModel_multiLanguageDisabled_matchDetected() {
    setDefaultLocalesRule.set(
        new LocaleList(Locale.forLanguageTag("en-GB"), Locale.forLanguageTag("zh-hk")));
    deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, false);

    ModelFile nonLocalePreferenceModelFile = createModelFile("zh", /* version= */ 1);
    ModelFileManager modelFileManager = createModelFileManager(nonLocalePreferenceModelFile);
    final LocaleList requestLocalePreferences = new LocaleList(Locale.forLanguageTag("en"));
    final LocaleList detectedLocalePreferences = LocaleList.getEmptyLocaleList();

    ModelFile bestModelFile =
        modelFileManager.findBestModelFile(
            MODEL_TYPE, requestLocalePreferences, detectedLocalePreferences);

    assertThat(bestModelFile).isNull();
  }

  /** With the downloader enabled, the lister reports what the mocked manager returns per type. */
  @Test
  public void downloaderModelsLister() throws IOException {
    File annotatorFile = new File(rootTestDir, "annotator.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), annotatorFile);
    File langIdFile = new File(rootTestDir, "langId.model");
    Files.copy(TestDataUtils.getLangIdModelFile(), langIdFile);

    deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);

    DownloaderModelsLister downloaderModelsLister =
        new DownloaderModelsLister(modelDownloadManager, settings);

    when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(Arrays.asList(annotatorFile));
    when(downloadedModelManager.listModels(ModelType.LANG_ID))
        .thenReturn(Arrays.asList(langIdFile));
    when(downloadedModelManager.listModels(ModelType.ACTIONS_SUGGESTIONS))
        .thenReturn(new ArrayList<>());

    assertThat(downloaderModelsLister.list(MODEL_TYPE))
        .containsExactly(ModelFile.createFromRegularFile(annotatorFile, MODEL_TYPE));
    assertThat(downloaderModelsLister.list(ModelType.LANG_ID))
        .containsExactly(ModelFile.createFromRegularFile(langIdFile, ModelType.LANG_ID));
    assertThat(downloaderModelsLister.list(ModelType.ACTIONS_SUGGESTIONS)).isEmpty();
  }

  /** With the downloader enabled, the manager's listing includes the downloaded model. */
  @Test
  public void downloaderModelsLister_checkModelFileManager() throws IOException {
    File annotatorFile = new File(rootTestDir, "test.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), annotatorFile);

    deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
    when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(Arrays.asList(annotatorFile));

    assertThat(modelFileManager.listModelFiles(MODEL_TYPE))
        .contains(ModelFile.createFromRegularFile(annotatorFile, MODEL_TYPE));
  }

  /** With the downloader disabled, the lister reports nothing even if a model is available. */
  @Test
  public void downloaderModelsLister_disabled() throws IOException {
    File annotatorFile = new File(rootTestDir, "test.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), annotatorFile);

    deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, false);
    DownloaderModelsLister downloaderModelsLister =
        new DownloaderModelsLister(modelDownloadManager, settings);
    when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(Arrays.asList(annotatorFile));

    assertThat(downloaderModelsLister.list(MODEL_TYPE)).isEmpty();
  }

  /** Only the exact file handed to the lister is reported; its sibling is ignored. */
  @Test
  public void regularFileFullMatchLister() throws IOException {
    File modelFile = new File(rootTestDir, "test.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile);
    File wrongFile = new File(rootTestDir, "wrong.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), wrongFile);

    RegularFileFullMatchLister regularFileFullMatchLister =
        new RegularFileFullMatchLister(MODEL_TYPE, modelFile, () -> true);
    ImmutableList<ModelFile> listedModels = regularFileFullMatchLister.list(MODEL_TYPE);

    assertThat(listedModels).hasSize(1);
    assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile.getAbsolutePath());
    assertThat(listedModels.get(0).isAsset).isFalse();
  }

  /** Files whose names match the pattern are reported; the non-matching file is ignored. */
  @Test
  public void regularFilePatternMatchLister() throws IOException {
    File modelFile1 = new File(rootTestDir, "annotator.en.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
    File modelFile2 = new File(rootTestDir, "annotator.fr.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile2);
    File mismatchedModelFile = new File(rootTestDir, "actions.en.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), mismatchedModelFile);

    RegularFilePatternMatchLister regularFilePatternMatchLister =
        new RegularFilePatternMatchLister(
            MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> true);
    ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);

    assertThat(listedModels).hasSize(2);
    assertThat(listedModels.get(0).isAsset).isFalse();
    assertThat(listedModels.get(1).isAsset).isFalse();
    assertThat(ImmutableList.of(listedModels.get(0).absolutePath, listedModels.get(1).absolutePath))
        .containsExactly(modelFile1.getAbsolutePath(), modelFile2.getAbsolutePath());
  }

  /** When the enabled-supplier returns false, the lister reports nothing. */
  @Test
  public void regularFilePatternMatchLister_disabled() throws IOException {
    File modelFile1 = new File(rootTestDir, "annotator.en.model");
    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);

    RegularFilePatternMatchLister regularFilePatternMatchLister =
        new RegularFilePatternMatchLister(
            MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> false);
    ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);

    assertThat(listedModels).isEmpty();
  }

  /**
   * Creates a manager backed by a single lister that returns {@code modelFiles} for every model
   * type, using the shared test {@link #settings}.
   */
  private ModelFileManager createModelFileManager(ModelFile... modelFiles) {
    return new ModelFileManagerImpl(
        ApplicationProvider.getApplicationContext(),
        ImmutableList.of(modelType -> ImmutableList.copyOf(modelFiles)),
        settings);
  }

  /**
   * Creates a non-asset {@link ModelFile} of type {@link #MODEL_TYPE} whose path is {@code
   * <rootTestDir>/<supportedLocaleTags>-<version>}. No file is created on disk.
   */
  private ModelFile createModelFile(String supportedLocaleTags, int version) {
    return new ModelFile(
        MODEL_TYPE,
        new File(rootTestDir, String.format("%s-%d", supportedLocaleTags, version))
            .getAbsolutePath(),
        version,
        supportedLocaleTags,
        /* isAsset= */ false);
  }

  /** Deletes {@code f}; when it is a directory, its contents are deleted first. */
  private static void recursiveDelete(File f) {
    if (f.isDirectory()) {
      // listFiles() may return null on an I/O error; skip the walk instead of throwing from
      // the @After method and masking the actual test result.
      File[] children = f.listFiles();
      if (children != null) {
        for (File innerFile : children) {
          recursiveDelete(innerFile);
        }
      }
    }
    f.delete();
  }
}