• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.textclassifier;
18 
19 import static com.android.textclassifier.common.ModelFile.LANGUAGE_INDEPENDENT;
20 import static com.google.common.truth.Truth.assertThat;
21 import static org.mockito.Mockito.when;
22 
23 import android.content.Context;
24 import android.os.LocaleList;
25 import androidx.test.core.app.ApplicationProvider;
26 import androidx.test.ext.junit.runners.AndroidJUnit4;
27 import androidx.test.filters.SmallTest;
28 import androidx.work.WorkManager;
29 import com.android.textclassifier.ModelFileManagerImpl.DownloaderModelsLister;
30 import com.android.textclassifier.ModelFileManagerImpl.RegularFileFullMatchLister;
31 import com.android.textclassifier.ModelFileManagerImpl.RegularFilePatternMatchLister;
32 import com.android.textclassifier.common.ModelFile;
33 import com.android.textclassifier.common.ModelType;
34 import com.android.textclassifier.common.ModelType.ModelTypeDef;
35 import com.android.textclassifier.common.TextClassifierSettings;
36 import com.android.textclassifier.downloader.DownloadedModelManager;
37 import com.android.textclassifier.downloader.ModelDownloadManager;
38 import com.android.textclassifier.downloader.ModelDownloadWorker;
39 import com.android.textclassifier.testing.SetDefaultLocalesRule;
40 import com.android.textclassifier.testing.TestingDeviceConfig;
41 import com.google.common.collect.ImmutableList;
42 import com.google.common.io.Files;
43 import com.google.common.util.concurrent.MoreExecutors;
44 import java.io.File;
45 import java.io.IOException;
46 import java.util.ArrayList;
47 import java.util.Arrays;
48 import java.util.List;
49 import java.util.Locale;
50 import java.util.stream.Collectors;
51 import org.junit.After;
52 import org.junit.Before;
53 import org.junit.Rule;
54 import org.junit.Test;
55 import org.junit.runner.RunWith;
56 import org.mockito.Mock;
57 import org.mockito.junit.MockitoJUnit;
58 import org.mockito.junit.MockitoRule;
59 
60 @SmallTest
61 @RunWith(AndroidJUnit4.class)
62 public final class ModelFileManagerImplTest {
63   private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");
64 
65   @ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
66 
67   private TestingDeviceConfig deviceConfig;
68 
69   @Mock private DownloadedModelManager downloadedModelManager;
70 
71   @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
72   @Rule public final MockitoRule mocks = MockitoJUnit.rule();
73 
74   private File rootTestDir;
75   private ModelFileManagerImpl modelFileManager;
76   private ModelDownloadManager modelDownloadManager;
77   private TextClassifierSettings settings;
78 
79   @Before
setup()80   public void setup() {
81     deviceConfig = new TestingDeviceConfig();
82     rootTestDir =
83         new File(ApplicationProvider.getApplicationContext().getCacheDir(), "rootTestDir");
84     rootTestDir.mkdirs();
85     Context context = ApplicationProvider.getApplicationContext();
86     settings = new TextClassifierSettings(deviceConfig, /* isWear= */ false);
87     modelDownloadManager =
88         new ModelDownloadManager(
89             context,
90             ModelDownloadWorker.class,
91             () -> WorkManager.getInstance(context),
92             downloadedModelManager,
93             settings,
94             MoreExecutors.newDirectExecutorService());
95     modelFileManager = new ModelFileManagerImpl(context, modelDownloadManager, settings);
96     setDefaultLocalesRule.set(new LocaleList(DEFAULT_LOCALE));
97   }
98 
99   @After
removeTestDir()100   public void removeTestDir() {
101     recursiveDelete(rootTestDir);
102   }
103 
104   @Test
annotatorModelPreloaded()105   public void annotatorModelPreloaded() {
106     verifyModelPreloadedAsAsset(ModelType.ANNOTATOR, "textclassifier/annotator.universal.model");
107   }
108 
109   @Test
actionsModelPreloaded()110   public void actionsModelPreloaded() {
111     verifyModelPreloadedAsAsset(
112         ModelType.ACTIONS_SUGGESTIONS, "textclassifier/actions_suggestions.universal.model");
113   }
114 
115   @Test
langIdModelPreloaded()116   public void langIdModelPreloaded() {
117     verifyModelPreloadedAsAsset(ModelType.LANG_ID, "textclassifier/lang_id.model");
118   }
119 
verifyModelPreloadedAsAsset( @odelTypeDef String modelType, String expectedModelPath)120   private void verifyModelPreloadedAsAsset(
121       @ModelTypeDef String modelType, String expectedModelPath) {
122     List<ModelFile> modelFiles = modelFileManager.listModelFiles(modelType);
123     List<ModelFile> assetFiles =
124         modelFiles.stream().filter(modelFile -> modelFile.isAsset).collect(Collectors.toList());
125 
126     assertThat(assetFiles).hasSize(1);
127     assertThat(assetFiles.get(0).absolutePath).isEqualTo(expectedModelPath);
128   }
129 
130   @Test
findBestModel_versionCode()131   public void findBestModel_versionCode() {
132     ModelFile olderModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
133     ModelFile newerModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 2);
134     ModelFileManager modelFileManager = createModelFileManager(olderModelFile, newerModelFile);
135 
136     ModelFile bestModelFile =
137         modelFileManager.findBestModelFile(
138             MODEL_TYPE, /* localePreferences= */ null, /* detectedLocales= */ null);
139     assertThat(bestModelFile).isEqualTo(newerModelFile);
140   }
141 
142   @Test
findBestModel_languageDependentModelIsPreferred()143   public void findBestModel_languageDependentModelIsPreferred() {
144     ModelFile languageIndependentModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
145     ModelFile languageDependentModelFile =
146         createModelFile(DEFAULT_LOCALE.toLanguageTag(), /* version */ 1);
147     ModelFileManager modelFileManager =
148         createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
149 
150     ModelFile bestModelFile =
151         modelFileManager.findBestModelFile(
152             MODEL_TYPE, new LocaleList(DEFAULT_LOCALE), /* detectedLocales= */ null);
153     assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
154   }
155 
156   @Test
findBestModel_noMatchedLanguageModel()157   public void findBestModel_noMatchedLanguageModel() {
158     ModelFile languageIndependentModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
159     ModelFile languageDependentModelFile = createModelFile("zh-hk", /* version */ 1);
160     ModelFileManager modelFileManager =
161         createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
162 
163     ModelFile bestModelFile =
164         modelFileManager.findBestModelFile(
165             MODEL_TYPE, new LocaleList(DEFAULT_LOCALE), /* detectedLocales= */ null);
166     assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
167   }
168 
169   @Test
findBestModel_languageIsMoreImportantThanVersion()170   public void findBestModel_languageIsMoreImportantThanVersion() {
171     ModelFile matchButOlderModel = createModelFile(DEFAULT_LOCALE.toLanguageTag(), /* version */ 1);
172     ModelFile mismatchButNewerModel = createModelFile("zh-hk", /* version */ 2);
173     ModelFileManager modelFileManager =
174         createModelFileManager(matchButOlderModel, mismatchButNewerModel);
175 
176     ModelFile bestModelFile =
177         modelFileManager.findBestModelFile(
178             MODEL_TYPE, new LocaleList(DEFAULT_LOCALE), /* detectedLocales= */ null);
179     assertThat(bestModelFile).isEqualTo(matchButOlderModel);
180   }
181 
182   @Test
findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_onlyCheckLanguage()183   public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_onlyCheckLanguage() {
184     setDefaultLocalesRule.set(LocaleList.forLanguageTags("zh"));
185     ModelFile languageIndependentModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
186     ModelFile languageDependentModelFile = createModelFile("zh", /* version */ 1);
187     ModelFileManager modelFileManager =
188         createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
189 
190     ModelFile bestModelFile =
191         modelFileManager.findBestModelFile(
192             MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"), /* detectedLocales= */ null);
193     assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
194   }
195 
196   @Test
findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_match()197   public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_match() {
198     setDefaultLocalesRule.set(LocaleList.forLanguageTags("zh-hk"));
199     ModelFile languageIndependentModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
200     ModelFile languageDependentModelFile = createModelFile("zh", /* version */ 1);
201     ModelFileManager modelFileManager =
202         createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
203 
204     ModelFile bestModelFile =
205         modelFileManager.findBestModelFile(
206             MODEL_TYPE, LocaleList.forLanguageTags("zh"), /* detectedLocales= */ null);
207     assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
208   }
209 
210   @Test
findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_doNotMatch()211   public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_doNotMatch() {
212     setDefaultLocalesRule.set(LocaleList.forLanguageTags("en"));
213     ModelFile languageIndependentModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
214     ModelFile languageDependentModelFile = createModelFile("zh", /* version */ 1);
215     ModelFileManager modelFileManager =
216         createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
217 
218     ModelFile bestModelFile =
219         modelFileManager.findBestModelFile(
220             MODEL_TYPE, LocaleList.forLanguageTags("zh"), /* detectedLocales= */ null);
221     assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
222   }
223 
224   @Test
findBestModel_onlyPrimaryLocaleConsidered_noLocalePreferencesProvided()225   public void findBestModel_onlyPrimaryLocaleConsidered_noLocalePreferencesProvided() {
226     setDefaultLocalesRule.set(
227         new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")));
228     ModelFile languageIndependentModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
229     ModelFile nonPrimaryLocaleModelFile = createModelFile("zh-hk", /* version */ 1);
230     ModelFileManager modelFileManager =
231         createModelFileManager(languageIndependentModelFile, nonPrimaryLocaleModelFile);
232 
233     ModelFile bestModelFile =
234         modelFileManager.findBestModelFile(
235             MODEL_TYPE, /* localePreferences= */ null, /* detectedLocales= */ null);
236     assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
237   }
238 
239   @Test
findBestModel_onlyPrimaryLocaleConsidered_localePreferencesProvided()240   public void findBestModel_onlyPrimaryLocaleConsidered_localePreferencesProvided() {
241     setDefaultLocalesRule.set(
242         new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")));
243 
244     ModelFile languageIndependentModelFile = createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
245     ModelFile nonPrimaryLocalePreferenceModelFile = createModelFile("zh-hk", /* version */ 1);
246     ModelFileManager modelFileManager =
247         createModelFileManager(languageIndependentModelFile, nonPrimaryLocalePreferenceModelFile);
248 
249     ModelFile bestModelFile =
250         modelFileManager.findBestModelFile(
251             MODEL_TYPE,
252             new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")),
253             /* detectedLocales= */ null);
254     assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
255   }
256 
257   @Test
findBestModel_multiLanguageEnabled_noMatchedModel()258   public void findBestModel_multiLanguageEnabled_noMatchedModel() {
259     setDefaultLocalesRule.set(LocaleList.forLanguageTags("en"));
260     deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, true);
261 
262     ModelFile primaryLocalePreferenceModelFile = createModelFile("en", /* version= */ 1);
263     ModelFile secondaryLocalePreferencetModelFile = createModelFile("zh-hk", /* version */ 1);
264     ModelFileManager modelFileManager =
265         createModelFileManager(
266             primaryLocalePreferenceModelFile, secondaryLocalePreferencetModelFile);
267     final LocaleList requestLocalePreferences =
268         new LocaleList(Locale.forLanguageTag("ja"), Locale.forLanguageTag("fy"));
269     final LocaleList detectedLocalePreferences = LocaleList.forLanguageTags("hr");
270 
271     ModelFile bestModelFile =
272         modelFileManager.findBestModelFile(
273             MODEL_TYPE, requestLocalePreferences, detectedLocalePreferences);
274     assertThat(bestModelFile).isEqualTo(primaryLocalePreferenceModelFile);
275   }
276 
277   @Test
findBestModel_multiLanguageEnabled_matchDetected()278   public void findBestModel_multiLanguageEnabled_matchDetected() {
279     setDefaultLocalesRule.set(
280         new LocaleList(Locale.forLanguageTag("en-GB"), Locale.forLanguageTag("zh-hk")));
281     deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, true);
282 
283     ModelFile localePreferenceModelFile = createModelFile("zh", /*version*/ 1);
284     ModelFileManager modelFileManager = createModelFileManager(localePreferenceModelFile);
285     final LocaleList requestLocalePreferences =
286         new LocaleList(Locale.forLanguageTag("ja"), Locale.forLanguageTag("zh"));
287     final LocaleList detectedLocalePreferences = LocaleList.forLanguageTags("zh");
288 
289     ModelFile bestModelFile =
290         modelFileManager.findBestModelFile(
291             MODEL_TYPE, requestLocalePreferences, detectedLocalePreferences);
292     assertThat(bestModelFile).isEqualTo(localePreferenceModelFile);
293   }
294 
295   @Test
findBestModel_multiLanguageDisabled_matchDetected()296   public void findBestModel_multiLanguageDisabled_matchDetected() {
297     setDefaultLocalesRule.set(
298         new LocaleList(Locale.forLanguageTag("en-GB"), Locale.forLanguageTag("zh-hk")));
299     deviceConfig.setConfig(TextClassifierSettings.MULTI_LANGUAGE_SUPPORT_ENABLED, false);
300 
301     ModelFile nonLocalePreferenceModelFile = createModelFile("zh", /*version*/ 1);
302     ModelFileManager modelFileManager = createModelFileManager(nonLocalePreferenceModelFile);
303     final LocaleList requestLocalePreferences = new LocaleList(Locale.forLanguageTag("en"));
304     final LocaleList detectedLocalePreferences = LocaleList.getEmptyLocaleList();
305 
306     ModelFile bestModelFile =
307         modelFileManager.findBestModelFile(
308             MODEL_TYPE, requestLocalePreferences, detectedLocalePreferences);
309     assertThat(bestModelFile).isEqualTo(null);
310   }
311 
312   @Test
downloaderModelsLister()313   public void downloaderModelsLister() throws IOException {
314     File annotatorFile = new File(rootTestDir, "annotator.model");
315     Files.copy(TestDataUtils.getTestAnnotatorModelFile(), annotatorFile);
316     File langIdFile = new File(rootTestDir, "langId.model");
317     Files.copy(TestDataUtils.getLangIdModelFile(), langIdFile);
318 
319     deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
320 
321     DownloaderModelsLister downloaderModelsLister =
322         new DownloaderModelsLister(modelDownloadManager, settings);
323 
324     when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(Arrays.asList(annotatorFile));
325     when(downloadedModelManager.listModels(ModelType.LANG_ID))
326         .thenReturn(Arrays.asList(langIdFile));
327     when(downloadedModelManager.listModels(ModelType.ACTIONS_SUGGESTIONS))
328         .thenReturn(new ArrayList<>());
329     assertThat(downloaderModelsLister.list(MODEL_TYPE))
330         .containsExactly(ModelFile.createFromRegularFile(annotatorFile, MODEL_TYPE));
331     assertThat(downloaderModelsLister.list(ModelType.LANG_ID))
332         .containsExactly(ModelFile.createFromRegularFile(langIdFile, ModelType.LANG_ID));
333     assertThat(downloaderModelsLister.list(ModelType.ACTIONS_SUGGESTIONS)).isEmpty();
334   }
335 
336   @Test
downloaderModelsLister_checkModelFileManager()337   public void downloaderModelsLister_checkModelFileManager() throws IOException {
338     File annotatorFile = new File(rootTestDir, "test.model");
339     Files.copy(TestDataUtils.getTestAnnotatorModelFile(), annotatorFile);
340 
341     deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
342     when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(Arrays.asList(annotatorFile));
343     assertThat(modelFileManager.listModelFiles(MODEL_TYPE))
344         .contains(ModelFile.createFromRegularFile(annotatorFile, MODEL_TYPE));
345   }
346 
347   @Test
downloaderModelsLister_disabled()348   public void downloaderModelsLister_disabled() throws IOException {
349     File annotatorFile = new File(rootTestDir, "test.model");
350     Files.copy(TestDataUtils.getTestAnnotatorModelFile(), annotatorFile);
351 
352     deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, false);
353     DownloaderModelsLister downloaderModelsLister =
354         new DownloaderModelsLister(modelDownloadManager, settings);
355     when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(Arrays.asList(annotatorFile));
356     assertThat(downloaderModelsLister.list(MODEL_TYPE)).isEmpty();
357   }
358 
359   @Test
regularFileFullMatchLister()360   public void regularFileFullMatchLister() throws IOException {
361     File modelFile = new File(rootTestDir, "test.model");
362     Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile);
363     File wrongFile = new File(rootTestDir, "wrong.model");
364     Files.copy(TestDataUtils.getTestAnnotatorModelFile(), wrongFile);
365 
366     RegularFileFullMatchLister regularFileFullMatchLister =
367         new RegularFileFullMatchLister(MODEL_TYPE, modelFile, () -> true);
368     ImmutableList<ModelFile> listedModels = regularFileFullMatchLister.list(MODEL_TYPE);
369 
370     assertThat(listedModels).hasSize(1);
371     assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile.getAbsolutePath());
372     assertThat(listedModels.get(0).isAsset).isFalse();
373   }
374 
375   @Test
regularFilePatternMatchLister()376   public void regularFilePatternMatchLister() throws IOException {
377     File modelFile1 = new File(rootTestDir, "annotator.en.model");
378     Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
379     File modelFile2 = new File(rootTestDir, "annotator.fr.model");
380     Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile2);
381     File mismatchedModelFile = new File(rootTestDir, "actions.en.model");
382     Files.copy(TestDataUtils.getTestAnnotatorModelFile(), mismatchedModelFile);
383 
384     RegularFilePatternMatchLister regularFilePatternMatchLister =
385         new RegularFilePatternMatchLister(
386             MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> true);
387     ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
388 
389     assertThat(listedModels).hasSize(2);
390     assertThat(listedModels.get(0).isAsset).isFalse();
391     assertThat(listedModels.get(1).isAsset).isFalse();
392     assertThat(ImmutableList.of(listedModels.get(0).absolutePath, listedModels.get(1).absolutePath))
393         .containsExactly(modelFile1.getAbsolutePath(), modelFile2.getAbsolutePath());
394   }
395 
396   @Test
regularFilePatternMatchLister_disabled()397   public void regularFilePatternMatchLister_disabled() throws IOException {
398     File modelFile1 = new File(rootTestDir, "annotator.en.model");
399     Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
400 
401     RegularFilePatternMatchLister regularFilePatternMatchLister =
402         new RegularFilePatternMatchLister(
403             MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> false);
404     ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
405 
406     assertThat(listedModels).isEmpty();
407   }
408 
createModelFileManager(ModelFile... modelFiles)409   private ModelFileManager createModelFileManager(ModelFile... modelFiles) {
410     return new ModelFileManagerImpl(
411         ApplicationProvider.getApplicationContext(),
412         ImmutableList.of(modelType -> ImmutableList.copyOf(modelFiles)),
413         settings);
414   }
415 
createModelFile(String supportedLocaleTags, int version)416   private ModelFile createModelFile(String supportedLocaleTags, int version) {
417     return new ModelFile(
418         MODEL_TYPE,
419         new File(rootTestDir, String.format("%s-%d", supportedLocaleTags, version))
420             .getAbsolutePath(),
421         version,
422         supportedLocaleTags,
423         /* isAsset= */ false);
424   }
425 
recursiveDelete(File f)426   private static void recursiveDelete(File f) {
427     if (f.isDirectory()) {
428       for (File innerFile : f.listFiles()) {
429         recursiveDelete(innerFile);
430       }
431     }
432     f.delete();
433   }
434 }
435