• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package com.android.textclassifier.downloader;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 
21 import android.content.Context;
22 import androidx.room.Room;
23 import androidx.test.core.app.ApplicationProvider;
24 import com.android.textclassifier.common.ModelType;
25 import com.android.textclassifier.common.ModelType.ModelTypeDef;
26 import com.android.textclassifier.common.TextClassifierSettings;
27 import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
28 import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
29 import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestModelCrossRef;
30 import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
31 import com.android.textclassifier.testing.TestingDeviceConfig;
32 import com.google.common.collect.ImmutableMap;
33 import java.io.File;
34 import org.junit.After;
35 import org.junit.Before;
36 import org.junit.Test;
37 import org.junit.runner.RunWith;
38 import org.junit.runners.JUnit4;
39 
40 @RunWith(JUnit4.class)
41 public final class DownloadedModelManagerImplTest {
42 
43   private File modelDownloaderDir;
44   private DownloadedModelDatabase db;
45   private DownloadedModelManagerImpl downloadedModelManagerImpl;
46   private TestingDeviceConfig deviceConfig;
47   private TextClassifierSettings settings;
48 
49   @Before
setUp()50   public void setUp() {
51     Context context = ApplicationProvider.getApplicationContext();
52     modelDownloaderDir = new File(context.getFilesDir(), "test_dir");
53     modelDownloaderDir.mkdirs();
54     deviceConfig = new TestingDeviceConfig();
55     settings = new TextClassifierSettings(deviceConfig, /* isWear= */ false);
56     db = Room.inMemoryDatabaseBuilder(context, DownloadedModelDatabase.class).build();
57     downloadedModelManagerImpl =
58         DownloadedModelManagerImpl.getInstanceForTesting(db, modelDownloaderDir, settings);
59   }
60 
61   @After
cleanUp()62   public void cleanUp() {
63     DownloaderTestUtils.deleteRecursively(modelDownloaderDir);
64     db.close();
65   }
66 
67   @Test
getModelDownloaderDir()68   public void getModelDownloaderDir() throws Exception {
69     modelDownloaderDir.delete();
70     assertThat(downloadedModelManagerImpl.getModelDownloaderDir().exists()).isTrue();
71     assertThat(downloadedModelManagerImpl.getModelDownloaderDir()).isEqualTo(modelDownloaderDir);
72   }
73 
74   @Test
listModels_cacheNotInitialized()75   public void listModels_cacheNotInitialized() throws Exception {
76     registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrlEn", "modelUrlEn", "modelPathEn");
77     registerManifestToDB(ModelType.ANNOTATOR, "zh", "manifestUrlZh", "modelUrlZh", "modelPathZh");
78 
79     assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
80         .containsExactly(new File("modelPathEn"), new File("modelPathZh"));
81     assertThat(downloadedModelManagerImpl.listModels(ModelType.LANG_ID)).isEmpty();
82   }
83 
84   @Test
listModels_doNotListBlockedModels()85   public void listModels_doNotListBlockedModels() throws Exception {
86     registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrlEn", "modelUrlEn", "modelPathEn");
87     registerManifestToDB(ModelType.ANNOTATOR, "zh", "manifestUrlZh", "modelUrlZh", "modelPathZh");
88     deviceConfig.setConfig(
89         TextClassifierSettings.MODEL_URL_BLOCKLIST,
90         String.format(
91             "%s%s%s",
92             "modelUrlEn", TextClassifierSettings.MODEL_URL_BLOCKLIST_SEPARATOR, "modelUrlXX"));
93 
94     assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
95         .containsExactly(new File("modelPathZh"));
96   }
97 
98   @Test
listModels_cacheNotUpdatedUnlessOnDownloadCompleted()99   public void listModels_cacheNotUpdatedUnlessOnDownloadCompleted() throws Exception {
100     registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrlEn", "modelUrlEn", "modelPathEn");
101     assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
102         .containsExactly(new File("modelPathEn"));
103 
104     registerManifestToDB(ModelType.ANNOTATOR, "zh", "manifestUrlZh", "modelUrlZh", "modelPathZh");
105     assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
106         .containsExactly(new File("modelPathEn"));
107 
108     ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload =
109         ImmutableMap.of(
110             ModelType.ANNOTATOR,
111             ManifestsToDownloadByType.create(ImmutableMap.of("zh", "manifestUrlZh")));
112     downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload);
113     assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
114         .contains(new File("modelPathZh"));
115   }
116 
117   @Test
getModel()118   public void getModel() throws Exception {
119     registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrl", "modelUrl", "modelPath");
120     assertThat(downloadedModelManagerImpl.getModel("modelUrl").getModelPath())
121         .isEqualTo("modelPath");
122     assertThat(downloadedModelManagerImpl.getModel("modelUrl2")).isNull();
123   }
124 
125   @Test
getManifest()126   public void getManifest() throws Exception {
127     registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrl", "modelUrl", "modelPath");
128     assertThat(downloadedModelManagerImpl.getManifest("manifestUrl")).isNotNull();
129     assertThat(downloadedModelManagerImpl.getManifest("manifestUrl2")).isNull();
130   }
131 
132   @Test
getManifestEnrollment()133   public void getManifestEnrollment() throws Exception {
134     registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrl", "modelUrl", "modelPath");
135     assertThat(
136             downloadedModelManagerImpl
137                 .getManifestEnrollment(ModelType.ANNOTATOR, "en")
138                 .getManifestUrl())
139         .isEqualTo("manifestUrl");
140     assertThat(downloadedModelManagerImpl.getManifestEnrollment(ModelType.ANNOTATOR, "zh"))
141         .isNull();
142   }
143 
144   @Test
registerModel()145   public void registerModel() throws Exception {
146     downloadedModelManagerImpl.registerModel("modelUrl", "modelPath");
147 
148     assertThat(downloadedModelManagerImpl.getModel("modelUrl").getModelPath())
149         .isEqualTo("modelPath");
150   }
151 
152   @Test
registerManifest()153   public void registerManifest() throws Exception {
154     downloadedModelManagerImpl.registerModel("modelUrl", "modelPath");
155     downloadedModelManagerImpl.registerManifest("manifestUrl", "modelUrl");
156 
157     assertThat(downloadedModelManagerImpl.getManifest("manifestUrl")).isNotNull();
158   }
159 
160   @Test
registerManifestDownloadFailure()161   public void registerManifestDownloadFailure() throws Exception {
162     downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl");
163 
164     Manifest manifest = downloadedModelManagerImpl.getManifest("manifestUrl");
165     assertThat(manifest.getStatus()).isEqualTo(Manifest.STATUS_FAILED);
166     assertThat(manifest.getFailureCounts()).isEqualTo(1);
167   }
168 
169   @Test
registerManifestEnrollment()170   public void registerManifestEnrollment() throws Exception {
171     downloadedModelManagerImpl.registerModel("modelUrl", "modelPath");
172     downloadedModelManagerImpl.registerManifest("manifestUrl", "modelUrl");
173     downloadedModelManagerImpl.registerManifestEnrollment(ModelType.ANNOTATOR, "en", "manifestUrl");
174 
175     ManifestEnrollment manifestEnrollment =
176         downloadedModelManagerImpl.getManifestEnrollment(ModelType.ANNOTATOR, "en");
177     assertThat(manifestEnrollment.getModelType()).isEqualTo(ModelType.ANNOTATOR);
178     assertThat(manifestEnrollment.getLocaleTag()).isEqualTo("en");
179     assertThat(manifestEnrollment.getManifestUrl()).isEqualTo("manifestUrl");
180   }
181 
182   @Test
onDownloadCompleted_newModelDownloaded()183   public void onDownloadCompleted_newModelDownloaded() throws Exception {
184     ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload =
185         ImmutableMap.of(
186             ModelType.ANNOTATOR,
187             ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl1")));
188     File modelFile1 = new File(modelDownloaderDir, "modelFile1");
189     modelFile1.createNewFile();
190     registerManifestToDB(
191         ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", modelFile1.getAbsolutePath());
192     downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload);
193 
194     assertThat(modelFile1.exists()).isTrue();
195     assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
196         .containsExactly(modelFile1);
197 
198     manifestsToDownload =
199         ImmutableMap.of(
200             ModelType.ANNOTATOR,
201             ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl2")));
202     File modelFile2 = new File(modelDownloaderDir, "modelFile2");
203     modelFile2.createNewFile();
204     registerManifestToDB(
205         ModelType.ANNOTATOR, "en", "manifestUrl2", "modelUrl2", modelFile2.getAbsolutePath());
206     downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload);
207 
208     assertThat(modelFile1.exists()).isFalse();
209     assertThat(modelFile2.exists()).isTrue();
210     assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
211         .containsExactly(modelFile2);
212   }
213 
214   @Test
onDownloadCompleted_newModelDownloadFailed()215   public void onDownloadCompleted_newModelDownloadFailed() throws Exception {
216     ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload =
217         ImmutableMap.of(
218             ModelType.ANNOTATOR,
219             ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl1")));
220     File modelFile1 = new File(modelDownloaderDir, "modelFile1");
221     modelFile1.createNewFile();
222     registerManifestToDB(
223         ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", modelFile1.getAbsolutePath());
224     downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload);
225 
226     assertThat(modelFile1.exists()).isTrue();
227     assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
228         .containsExactly(modelFile1);
229 
230     manifestsToDownload =
231         ImmutableMap.of(
232             ModelType.ANNOTATOR,
233             ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl2")));
234     downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl2");
235     downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload);
236 
237     assertThat(modelFile1.exists()).isTrue();
238     assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
239         .containsExactly(modelFile1);
240   }
241 
242   @Test
onDownloadCompleted_flatUnset()243   public void onDownloadCompleted_flatUnset() throws Exception {
244     ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload =
245         ImmutableMap.of(
246             ModelType.ANNOTATOR,
247             ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl1")));
248     File modelFile1 = new File(modelDownloaderDir, "modelFile1");
249     modelFile1.createNewFile();
250     registerManifestToDB(
251         ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", modelFile1.getAbsolutePath());
252     downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload);
253 
254     assertThat(modelFile1.exists()).isTrue();
255     assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
256         .containsExactly(modelFile1);
257 
258     manifestsToDownload = ImmutableMap.of();
259     downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload);
260 
261     assertThat(modelFile1.exists()).isFalse();
262     assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR)).isEmpty();
263   }
264 
265   @Test
onDownloadCompleted_cleanUpFailureRecords()266   public void onDownloadCompleted_cleanUpFailureRecords() throws Exception {
267     ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload =
268         ImmutableMap.of(
269             ModelType.ANNOTATOR,
270             ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl1")));
271     downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl1");
272     downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl2");
273     downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload);
274 
275     assertThat(downloadedModelManagerImpl.getManifest("manifestUrl1").getStatus())
276         .isEqualTo(Manifest.STATUS_FAILED);
277     assertThat(downloadedModelManagerImpl.getManifest("manifestUrl2")).isNull();
278   }
279 
280   @Test
onDownloadCompleted_modelsForMultipleLocalesDownloaded()281   public void onDownloadCompleted_modelsForMultipleLocalesDownloaded() throws Exception {
282     ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload =
283         ImmutableMap.of(
284             ModelType.ANNOTATOR,
285             ManifestsToDownloadByType.create(
286                 ImmutableMap.of("en", "manifestUrl1", "es", "manifestUrl2")));
287 
288     File modelFile1 = new File(modelDownloaderDir, "modelFile1");
289     modelFile1.createNewFile();
290     registerManifestToDB(
291         ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", modelFile1.getAbsolutePath());
292 
293     File modelFile2 = new File(modelDownloaderDir, "modelFile2");
294     modelFile2.createNewFile();
295     registerManifestToDB(
296         ModelType.ANNOTATOR, "es", "manifestUrl2", "modelUrl2", modelFile2.getAbsolutePath());
297 
298     downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload);
299     assertThat(modelFile1.exists()).isTrue();
300     assertThat(modelFile2.exists()).isTrue();
301     assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
302         .containsExactly(modelFile1, modelFile2);
303   }
304 
305   @Test
onDownloadCompleted_multipleLocales_oneDownloadFailed()306   public void onDownloadCompleted_multipleLocales_oneDownloadFailed() throws Exception {
307     File modelFile1 = new File(modelDownloaderDir, "modelFile1");
308     modelFile1.createNewFile();
309     registerManifestToDB(
310         ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", modelFile1.getAbsolutePath());
311 
312     ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload =
313         ImmutableMap.of(
314             ModelType.ANNOTATOR,
315             ManifestsToDownloadByType.create(
316                 ImmutableMap.of("es", "manifestUrl2", "en", "manifestUrl3")));
317     File modelFile2 = new File(modelDownloaderDir, "modelFile2");
318     modelFile2.createNewFile();
319     registerManifestToDB(
320         ModelType.ANNOTATOR, "es", "manifestUrl2", "modelUrl2", modelFile2.getAbsolutePath());
321     downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl3");
322     downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload);
323 
324     assertThat(modelFile1.exists()).isTrue();
325     assertThat(modelFile2.exists()).isTrue();
326     assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
327         .containsExactly(modelFile1, modelFile2);
328   }
329 
330   @Test
onDownoadCompleted_multipleLocales_replaceOldModel()331   public void onDownoadCompleted_multipleLocales_replaceOldModel() throws Exception {
332     File modelFile1 = new File(modelDownloaderDir, "modelFile1");
333     modelFile1.createNewFile();
334     registerManifestToDB(
335         ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", modelFile1.getAbsolutePath());
336 
337     ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload =
338         ImmutableMap.of(
339             ModelType.ANNOTATOR,
340             ManifestsToDownloadByType.create(
341                 ImmutableMap.of("en", "manifestUrl2", "es", "manifestUrl3")));
342 
343     File modelFile2 = new File(modelDownloaderDir, "modelFile2");
344     modelFile2.createNewFile();
345     registerManifestToDB(
346         ModelType.ANNOTATOR, "en", "manifestUrl2", "modelUrl2", modelFile2.getAbsolutePath());
347 
348     File modelFile3 = new File(modelDownloaderDir, "modelFile3");
349     modelFile3.createNewFile();
350     registerManifestToDB(
351         ModelType.ANNOTATOR, "es", "manifestUrl3", "modelUrl3", modelFile3.getAbsolutePath());
352 
353     downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload);
354     assertThat(modelFile2.exists()).isTrue();
355     assertThat(modelFile3.exists()).isTrue();
356     assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
357         .containsExactly(modelFile2, modelFile3);
358   }
359 
registerManifestToDB( @odelTypeDef String modelType, String localeTag, String manifestUrl, String modelUrl, String modelPath)360   private void registerManifestToDB(
361       @ModelTypeDef String modelType,
362       String localeTag,
363       String manifestUrl,
364       String modelUrl,
365       String modelPath) {
366     db.dao().insert(Model.create(modelUrl, modelPath));
367     db.dao()
368         .insert(Manifest.create(manifestUrl, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0));
369     db.dao().insert(ManifestModelCrossRef.create(manifestUrl, modelUrl));
370     db.dao().insert(ManifestEnrollment.create(modelType, localeTag, manifestUrl));
371   }
372 }
373