1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.textclassifier.downloader; 18 19 import static com.google.common.truth.Truth.assertThat; 20 21 import android.content.Context; 22 import androidx.room.Room; 23 import androidx.test.core.app.ApplicationProvider; 24 import com.android.textclassifier.common.ModelType; 25 import com.android.textclassifier.common.ModelType.ModelTypeDef; 26 import com.android.textclassifier.common.TextClassifierSettings; 27 import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest; 28 import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment; 29 import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestModelCrossRef; 30 import com.android.textclassifier.downloader.DownloadedModelDatabase.Model; 31 import com.android.textclassifier.testing.TestingDeviceConfig; 32 import com.google.common.collect.ImmutableMap; 33 import java.io.File; 34 import org.junit.After; 35 import org.junit.Before; 36 import org.junit.Test; 37 import org.junit.runner.RunWith; 38 import org.junit.runners.JUnit4; 39 40 @RunWith(JUnit4.class) 41 public final class DownloadedModelManagerImplTest { 42 43 private File modelDownloaderDir; 44 private DownloadedModelDatabase db; 45 private DownloadedModelManagerImpl downloadedModelManagerImpl; 46 private TestingDeviceConfig deviceConfig; 47 private TextClassifierSettings settings; 48 49 @Before setUp()50 public void setUp() { 51 Context context = ApplicationProvider.getApplicationContext(); 52 modelDownloaderDir = new File(context.getFilesDir(), "test_dir"); 53 modelDownloaderDir.mkdirs(); 54 deviceConfig = new TestingDeviceConfig(); 55 settings = new TextClassifierSettings(deviceConfig, /* isWear= */ false); 56 db = Room.inMemoryDatabaseBuilder(context, DownloadedModelDatabase.class).build(); 57 downloadedModelManagerImpl = 58 DownloadedModelManagerImpl.getInstanceForTesting(db, modelDownloaderDir, settings); 59 } 60 61 @After cleanUp()62 public void cleanUp() { 63 DownloaderTestUtils.deleteRecursively(modelDownloaderDir); 64 db.close(); 65 } 66 67 @Test getModelDownloaderDir()68 public void getModelDownloaderDir() throws Exception { 69 modelDownloaderDir.delete(); 70 assertThat(downloadedModelManagerImpl.getModelDownloaderDir().exists()).isTrue(); 71 assertThat(downloadedModelManagerImpl.getModelDownloaderDir()).isEqualTo(modelDownloaderDir); 72 } 73 74 @Test listModels_cacheNotInitialized()75 public void listModels_cacheNotInitialized() throws Exception { 76 registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrlEn", "modelUrlEn", "modelPathEn"); 77 registerManifestToDB(ModelType.ANNOTATOR, "zh", "manifestUrlZh", "modelUrlZh", "modelPathZh"); 78 79 assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR)) 80 .containsExactly(new File("modelPathEn"), new File("modelPathZh")); 81 assertThat(downloadedModelManagerImpl.listModels(ModelType.LANG_ID)).isEmpty(); 82 } 83 84 @Test listModels_doNotListBlockedModels()85 public void listModels_doNotListBlockedModels() throws Exception { 86 registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrlEn", "modelUrlEn", "modelPathEn"); 87 registerManifestToDB(ModelType.ANNOTATOR, "zh", "manifestUrlZh", "modelUrlZh", "modelPathZh"); 88 deviceConfig.setConfig( 89 TextClassifierSettings.MODEL_URL_BLOCKLIST, 90 String.format( 91 "%s%s%s", 92 "modelUrlEn", TextClassifierSettings.MODEL_URL_BLOCKLIST_SEPARATOR, "modelUrlXX")); 93 94 assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR)) 95 .containsExactly(new File("modelPathZh")); 96 } 97 98 @Test listModels_cacheNotUpdatedUnlessOnDownloadCompleted()99 public void listModels_cacheNotUpdatedUnlessOnDownloadCompleted() throws Exception { 100 registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrlEn", "modelUrlEn", "modelPathEn"); 101 assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR)) 102 .containsExactly(new File("modelPathEn")); 103 104 registerManifestToDB(ModelType.ANNOTATOR, "zh", "manifestUrlZh", "modelUrlZh", "modelPathZh"); 105 assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR)) 106 .containsExactly(new File("modelPathEn")); 107 108 ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload = 109 ImmutableMap.of( 110 ModelType.ANNOTATOR, 111 ManifestsToDownloadByType.create(ImmutableMap.of("zh", "manifestUrlZh"))); 112 downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload); 113 assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR)) 114 .contains(new File("modelPathZh")); 115 } 116 117 @Test getModel()118 public void getModel() throws Exception { 119 registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrl", "modelUrl", "modelPath"); 120 assertThat(downloadedModelManagerImpl.getModel("modelUrl").getModelPath()) 121 .isEqualTo("modelPath"); 122 assertThat(downloadedModelManagerImpl.getModel("modelUrl2")).isNull(); 123 } 124 125 @Test getManifest()126 public void getManifest() throws Exception { 127 registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrl", "modelUrl", "modelPath"); 128 assertThat(downloadedModelManagerImpl.getManifest("manifestUrl")).isNotNull(); 129 assertThat(downloadedModelManagerImpl.getManifest("manifestUrl2")).isNull(); 130 } 131 132 @Test getManifestEnrollment()133 public void getManifestEnrollment() throws Exception { 134 registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrl", "modelUrl", "modelPath"); 135 assertThat( 136 downloadedModelManagerImpl 137 .getManifestEnrollment(ModelType.ANNOTATOR, "en") 138 .getManifestUrl()) 139 .isEqualTo("manifestUrl"); 140 assertThat(downloadedModelManagerImpl.getManifestEnrollment(ModelType.ANNOTATOR, "zh")) 141 .isNull(); 142 } 143 144 @Test registerModel()145 public void registerModel() throws Exception { 146 downloadedModelManagerImpl.registerModel("modelUrl", "modelPath"); 147 148 assertThat(downloadedModelManagerImpl.getModel("modelUrl").getModelPath()) 149 .isEqualTo("modelPath"); 150 } 151 152 @Test registerManifest()153 public void registerManifest() throws Exception { 154 downloadedModelManagerImpl.registerModel("modelUrl", "modelPath"); 155 downloadedModelManagerImpl.registerManifest("manifestUrl", "modelUrl"); 156 157 assertThat(downloadedModelManagerImpl.getManifest("manifestUrl")).isNotNull(); 158 } 159 160 @Test registerManifestDownloadFailure()161 public void registerManifestDownloadFailure() throws Exception { 162 downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl"); 163 164 Manifest manifest = downloadedModelManagerImpl.getManifest("manifestUrl"); 165 assertThat(manifest.getStatus()).isEqualTo(Manifest.STATUS_FAILED); 166 assertThat(manifest.getFailureCounts()).isEqualTo(1); 167 } 168 169 @Test registerManifestEnrollment()170 public void registerManifestEnrollment() throws Exception { 171 downloadedModelManagerImpl.registerModel("modelUrl", "modelPath"); 172 downloadedModelManagerImpl.registerManifest("manifestUrl", "modelUrl"); 173 downloadedModelManagerImpl.registerManifestEnrollment(ModelType.ANNOTATOR, "en", "manifestUrl"); 174 175 ManifestEnrollment manifestEnrollment = 176 downloadedModelManagerImpl.getManifestEnrollment(ModelType.ANNOTATOR, "en"); 177 assertThat(manifestEnrollment.getModelType()).isEqualTo(ModelType.ANNOTATOR); 178 assertThat(manifestEnrollment.getLocaleTag()).isEqualTo("en"); 179 assertThat(manifestEnrollment.getManifestUrl()).isEqualTo("manifestUrl"); 180 } 181 182 @Test onDownloadCompleted_newModelDownloaded()183 public void onDownloadCompleted_newModelDownloaded() throws Exception { 184 ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload = 185 ImmutableMap.of( 186 ModelType.ANNOTATOR, 187 ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl1"))); 188 File modelFile1 = new File(modelDownloaderDir, "modelFile1"); 189 modelFile1.createNewFile(); 190 registerManifestToDB( 191 ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", modelFile1.getAbsolutePath()); 192 downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload); 193 194 assertThat(modelFile1.exists()).isTrue(); 195 assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR)) 196 .containsExactly(modelFile1); 197 198 manifestsToDownload = 199 ImmutableMap.of( 200 ModelType.ANNOTATOR, 201 ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl2"))); 202 File modelFile2 = new File(modelDownloaderDir, "modelFile2"); 203 modelFile2.createNewFile(); 204 registerManifestToDB( 205 ModelType.ANNOTATOR, "en", "manifestUrl2", "modelUrl2", modelFile2.getAbsolutePath()); 206 downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload); 207 208 assertThat(modelFile1.exists()).isFalse(); 209 assertThat(modelFile2.exists()).isTrue(); 210 assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR)) 211 .containsExactly(modelFile2); 212 } 213 214 @Test onDownloadCompleted_newModelDownloadFailed()215 public void onDownloadCompleted_newModelDownloadFailed() throws Exception { 216 ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload = 217 ImmutableMap.of( 218 ModelType.ANNOTATOR, 219 ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl1"))); 220 File modelFile1 = new File(modelDownloaderDir, "modelFile1"); 221 modelFile1.createNewFile(); 222 registerManifestToDB( 223 ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", modelFile1.getAbsolutePath()); 224 downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload); 225 226 assertThat(modelFile1.exists()).isTrue(); 227 assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR)) 228 .containsExactly(modelFile1); 229 230 manifestsToDownload = 231 ImmutableMap.of( 232 ModelType.ANNOTATOR, 233 ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl2"))); 234 downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl2"); 235 downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload); 236 237 assertThat(modelFile1.exists()).isTrue(); 238 assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR)) 239 .containsExactly(modelFile1); 240 } 241 242 @Test onDownloadCompleted_flatUnset()243 public void onDownloadCompleted_flatUnset() throws Exception { 244 ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload = 245 ImmutableMap.of( 246 ModelType.ANNOTATOR, 247 ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl1"))); 248 File modelFile1 = new File(modelDownloaderDir, "modelFile1"); 249 modelFile1.createNewFile(); 250 registerManifestToDB( 251 ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", modelFile1.getAbsolutePath()); 252 downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload); 253 254 assertThat(modelFile1.exists()).isTrue(); 255 assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR)) 256 .containsExactly(modelFile1); 257 258 manifestsToDownload = ImmutableMap.of(); 259 downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload); 260 261 assertThat(modelFile1.exists()).isFalse(); 262 assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR)).isEmpty(); 263 } 264 265 @Test onDownloadCompleted_cleanUpFailureRecords()266 public void onDownloadCompleted_cleanUpFailureRecords() throws Exception { 267 ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload = 268 ImmutableMap.of( 269 ModelType.ANNOTATOR, 270 ManifestsToDownloadByType.create(ImmutableMap.of("en", "manifestUrl1"))); 271 downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl1"); 272 downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl2"); 273 downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload); 274 275 assertThat(downloadedModelManagerImpl.getManifest("manifestUrl1").getStatus()) 276 .isEqualTo(Manifest.STATUS_FAILED); 277 assertThat(downloadedModelManagerImpl.getManifest("manifestUrl2")).isNull(); 278 } 279 280 @Test onDownloadCompleted_modelsForMultipleLocalesDownloaded()281 public void onDownloadCompleted_modelsForMultipleLocalesDownloaded() throws Exception { 282 ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload = 283 ImmutableMap.of( 284 ModelType.ANNOTATOR, 285 ManifestsToDownloadByType.create( 286 ImmutableMap.of("en", "manifestUrl1", "es", "manifestUrl2"))); 287 288 File modelFile1 = new File(modelDownloaderDir, "modelFile1"); 289 modelFile1.createNewFile(); 290 registerManifestToDB( 291 ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", modelFile1.getAbsolutePath()); 292 293 File modelFile2 = new File(modelDownloaderDir, "modelFile2"); 294 modelFile2.createNewFile(); 295 registerManifestToDB( 296 ModelType.ANNOTATOR, "es", "manifestUrl2", "modelUrl2", modelFile2.getAbsolutePath()); 297 298 downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload); 299 assertThat(modelFile1.exists()).isTrue(); 300 assertThat(modelFile2.exists()).isTrue(); 301 assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR)) 302 .containsExactly(modelFile1, modelFile2); 303 } 304 305 @Test onDownloadCompleted_multipleLocales_oneDownloadFailed()306 public void onDownloadCompleted_multipleLocales_oneDownloadFailed() throws Exception { 307 File modelFile1 = new File(modelDownloaderDir, "modelFile1"); 308 modelFile1.createNewFile(); 309 registerManifestToDB( 310 ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", modelFile1.getAbsolutePath()); 311 312 ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload = 313 ImmutableMap.of( 314 ModelType.ANNOTATOR, 315 ManifestsToDownloadByType.create( 316 ImmutableMap.of("es", "manifestUrl2", "en", "manifestUrl3"))); 317 File modelFile2 = new File(modelDownloaderDir, "modelFile2"); 318 modelFile2.createNewFile(); 319 registerManifestToDB( 320 ModelType.ANNOTATOR, "es", "manifestUrl2", "modelUrl2", modelFile2.getAbsolutePath()); 321 downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl3"); 322 downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload); 323 324 assertThat(modelFile1.exists()).isTrue(); 325 assertThat(modelFile2.exists()).isTrue(); 326 assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR)) 327 .containsExactly(modelFile1, modelFile2); 328 } 329 330 @Test onDownoadCompleted_multipleLocales_replaceOldModel()331 public void onDownoadCompleted_multipleLocales_replaceOldModel() throws Exception { 332 File modelFile1 = new File(modelDownloaderDir, "modelFile1"); 333 modelFile1.createNewFile(); 334 registerManifestToDB( 335 ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", modelFile1.getAbsolutePath()); 336 337 ImmutableMap<String, ManifestsToDownloadByType> manifestsToDownload = 338 ImmutableMap.of( 339 ModelType.ANNOTATOR, 340 ManifestsToDownloadByType.create( 341 ImmutableMap.of("en", "manifestUrl2", "es", "manifestUrl3"))); 342 343 File modelFile2 = new File(modelDownloaderDir, "modelFile2"); 344 modelFile2.createNewFile(); 345 registerManifestToDB( 346 ModelType.ANNOTATOR, "en", "manifestUrl2", "modelUrl2", modelFile2.getAbsolutePath()); 347 348 File modelFile3 = new File(modelDownloaderDir, "modelFile3"); 349 modelFile3.createNewFile(); 350 registerManifestToDB( 351 ModelType.ANNOTATOR, "es", "manifestUrl3", "modelUrl3", modelFile3.getAbsolutePath()); 352 353 downloadedModelManagerImpl.onDownloadCompleted(manifestsToDownload); 354 assertThat(modelFile2.exists()).isTrue(); 355 assertThat(modelFile3.exists()).isTrue(); 356 assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR)) 357 .containsExactly(modelFile2, modelFile3); 358 } 359 registerManifestToDB( @odelTypeDef String modelType, String localeTag, String manifestUrl, String modelUrl, String modelPath)360 private void registerManifestToDB( 361 @ModelTypeDef String modelType, 362 String localeTag, 363 String manifestUrl, 364 String modelUrl, 365 String modelPath) { 366 db.dao().insert(Model.create(modelUrl, modelPath)); 367 db.dao() 368 .insert(Manifest.create(manifestUrl, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0)); 369 db.dao().insert(ManifestModelCrossRef.create(manifestUrl, modelUrl)); 370 db.dao().insert(ManifestEnrollment.create(modelType, localeTag, manifestUrl)); 371 } 372 } 373