1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.textclassifier.downloader; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static org.mockito.Mockito.when; 21 22 import android.content.Context; 23 import android.os.LocaleList; 24 import androidx.test.core.app.ApplicationProvider; 25 import androidx.test.ext.junit.runners.AndroidJUnit4; 26 import androidx.work.WorkInfo; 27 import androidx.work.WorkManager; 28 import androidx.work.testing.WorkManagerTestInitHelper; 29 import com.android.os.AtomsProto.TextClassifierDownloadWorkScheduled; 30 import com.android.os.AtomsProto.TextClassifierDownloadWorkScheduled.ReasonToSchedule; 31 import com.android.textclassifier.common.ModelType; 32 import com.android.textclassifier.common.TextClassifierSettings; 33 import com.android.textclassifier.common.statsd.TextClassifierDownloadLoggerTestRule; 34 import com.android.textclassifier.testing.SetDefaultLocalesRule; 35 import com.android.textclassifier.testing.TestingDeviceConfig; 36 import com.google.common.collect.ImmutableList; 37 import com.google.common.collect.Iterables; 38 import com.google.common.util.concurrent.MoreExecutors; 39 import java.io.File; 40 import java.util.List; 41 import java.util.Locale; 42 import java.util.stream.Collectors; 43 import org.junit.After; 44 import org.junit.Before; 45 import org.junit.Rule; 46 import org.junit.Test; 47 import org.junit.runner.RunWith; 48 import org.mockito.Mock; 49 import org.mockito.junit.MockitoJUnit; 50 import org.mockito.junit.MockitoRule; 51 52 @RunWith(AndroidJUnit4.class) 53 public final class ModelDownloadManagerTest { 54 private static final String MODEL_PATH = "/data/test.model"; 55 @ModelType.ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR; 56 private static final String LOCALE_TAG = "en"; 57 private static final LocaleList DEFAULT_LOCALE_LIST = new LocaleList(new Locale(LOCALE_TAG)); 58 59 @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule(); 60 61 @Rule 62 public final TextClassifierDownloadLoggerTestRule loggerTestRule = 63 new TextClassifierDownloadLoggerTestRule(); 64 65 @Rule public final MockitoRule mocks = MockitoJUnit.rule(); 66 67 private TestingDeviceConfig deviceConfig; 68 private WorkManager workManager; 69 private ModelDownloadManager downloadManager; 70 private ModelDownloadManager downloadManagerWithBadWorkManager; 71 @Mock DownloadedModelManager downloadedModelManager; 72 73 @Before setUp()74 public void setUp() { 75 Context context = ApplicationProvider.getApplicationContext(); 76 WorkManagerTestInitHelper.initializeTestWorkManager(context); 77 78 this.deviceConfig = new TestingDeviceConfig(); 79 this.workManager = WorkManager.getInstance(context); 80 this.downloadManager = 81 new ModelDownloadManager( 82 context, 83 ModelDownloadWorker.class, 84 () -> workManager, 85 downloadedModelManager, 86 new TextClassifierSettings(deviceConfig), 87 MoreExecutors.newDirectExecutorService()); 88 this.downloadManagerWithBadWorkManager = 89 new ModelDownloadManager( 90 context, 91 ModelDownloadWorker.class, 92 () -> { 93 throw new IllegalStateException("WorkManager may fail!"); 94 }, 95 downloadedModelManager, 96 new TextClassifierSettings(deviceConfig), 97 MoreExecutors.newDirectExecutorService()); 98 99 setDefaultLocalesRule.set(DEFAULT_LOCALE_LIST); 100 deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true); 101 } 102 103 @After tearDown()104 public void tearDown() { 105 workManager.cancelUniqueWork(ModelDownloadManager.UNIQUE_QUEUE_NAME); 106 DownloaderTestUtils.deleteRecursively( 107 ApplicationProvider.getApplicationContext().getFilesDir()); 108 } 109 110 @Test onTextClassifierServiceCreated_workManagerCrashed()111 public void onTextClassifierServiceCreated_workManagerCrashed() throws Exception { 112 assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty(); 113 downloadManagerWithBadWorkManager.onTextClassifierServiceCreated(); 114 115 // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test 116 TextClassifierDownloadWorkScheduled atom = 117 Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()); 118 assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.TCS_STARTED); 119 assertThat(atom.getFailedToSchedule()).isTrue(); 120 } 121 122 @Test onTextClassifierServiceCreated_requestEnqueued()123 public void onTextClassifierServiceCreated_requestEnqueued() throws Exception { 124 assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty(); 125 downloadManager.onTextClassifierServiceCreated(); 126 127 WorkInfo workInfo = 128 Iterables.getOnlyElement( 129 DownloaderTestUtils.queryWorkInfos( 130 workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME)); 131 assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED); 132 // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test 133 verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED); 134 } 135 136 @Test onTextClassifierServiceCreated_localeListOverridden()137 public void onTextClassifierServiceCreated_localeListOverridden() throws Exception { 138 assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty(); 139 deviceConfig.setConfig(TextClassifierSettings.TESTING_LOCALE_LIST_OVERRIDE, "zh,fr"); 140 downloadManager.onTextClassifierServiceCreated(); 141 142 assertThat(Locale.getDefault()).isEqualTo(Locale.forLanguageTag("zh")); 143 assertThat(LocaleList.getDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr")); 144 assertThat(LocaleList.getAdjustedDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr")); 145 // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test 146 verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED); 147 } 148 149 @Test onLocaleChanged_workManagerCrashed()150 public void onLocaleChanged_workManagerCrashed() throws Exception { 151 downloadManagerWithBadWorkManager.onLocaleChanged(); 152 153 TextClassifierDownloadWorkScheduled atom = 154 Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()); 155 assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.LOCALE_SETTINGS_CHANGED); 156 assertThat(atom.getFailedToSchedule()).isTrue(); 157 } 158 159 @Test onLocaleChanged_requestEnqueued()160 public void onLocaleChanged_requestEnqueued() throws Exception { 161 downloadManager.onLocaleChanged(); 162 163 WorkInfo workInfo = 164 Iterables.getOnlyElement( 165 DownloaderTestUtils.queryWorkInfos( 166 workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME)); 167 assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED); 168 verifyWorkScheduledLogging(ReasonToSchedule.LOCALE_SETTINGS_CHANGED); 169 } 170 171 @Test onTextClassifierDeviceConfigChanged_workManagerCrashed()172 public void onTextClassifierDeviceConfigChanged_workManagerCrashed() throws Exception { 173 downloadManagerWithBadWorkManager.onTextClassifierDeviceConfigChanged(); 174 175 TextClassifierDownloadWorkScheduled atom = 176 Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()); 177 assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.DEVICE_CONFIG_UPDATED); 178 assertThat(atom.getFailedToSchedule()).isTrue(); 179 } 180 181 @Test onTextClassifierDeviceConfigChanged_requestEnqueued()182 public void onTextClassifierDeviceConfigChanged_requestEnqueued() throws Exception { 183 downloadManager.onTextClassifierDeviceConfigChanged(); 184 185 WorkInfo workInfo = 186 Iterables.getOnlyElement( 187 DownloaderTestUtils.queryWorkInfos( 188 workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME)); 189 assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED); 190 verifyWorkScheduledLogging(ReasonToSchedule.DEVICE_CONFIG_UPDATED); 191 } 192 193 @Test onTextClassifierDeviceConfigChanged_downloaderDisabled()194 public void onTextClassifierDeviceConfigChanged_downloaderDisabled() throws Exception { 195 deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, false); 196 downloadManager.onTextClassifierDeviceConfigChanged(); 197 198 assertThat( 199 DownloaderTestUtils.queryWorkInfos(workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME)) 200 .isEmpty(); 201 assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty(); 202 } 203 204 @Test onTextClassifierDeviceConfigChanged_newWorkDoNotReplaceOldWork()205 public void onTextClassifierDeviceConfigChanged_newWorkDoNotReplaceOldWork() throws Exception { 206 downloadManager.onTextClassifierDeviceConfigChanged(); 207 downloadManager.onTextClassifierDeviceConfigChanged(); 208 List<WorkInfo> workInfos = 209 DownloaderTestUtils.queryWorkInfos(workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME); 210 211 assertThat(workInfos.stream().map(WorkInfo::getState).collect(Collectors.toList())) 212 .containsExactly(WorkInfo.State.ENQUEUED, WorkInfo.State.BLOCKED); 213 List<TextClassifierDownloadWorkScheduled> atoms = 214 loggerTestRule.getLoggedDownloadWorkScheduledAtoms(); 215 assertThat(atoms).hasSize(2); 216 verifyWorkScheduledAtom(atoms.get(0), ReasonToSchedule.DEVICE_CONFIG_UPDATED); 217 verifyWorkScheduledAtom(atoms.get(1), ReasonToSchedule.DEVICE_CONFIG_UPDATED); 218 } 219 220 @Test onTextClassifierDeviceConfigChanged_localeListOverridden()221 public void onTextClassifierDeviceConfigChanged_localeListOverridden() throws Exception { 222 deviceConfig.setConfig(TextClassifierSettings.TESTING_LOCALE_LIST_OVERRIDE, "zh,fr"); 223 downloadManager.onTextClassifierDeviceConfigChanged(); 224 225 assertThat(Locale.getDefault()).isEqualTo(Locale.forLanguageTag("zh")); 226 assertThat(LocaleList.getDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr")); 227 assertThat(LocaleList.getAdjustedDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr")); 228 verifyWorkScheduledLogging(ReasonToSchedule.DEVICE_CONFIG_UPDATED); 229 } 230 231 @Test listDownloadedModels()232 public void listDownloadedModels() throws Exception { 233 File modelFile = new File(MODEL_PATH); 234 when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(ImmutableList.of(modelFile)); 235 236 assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).containsExactly(modelFile); 237 } 238 239 @Test listDownloadedModels_doNotCrashOnError()240 public void listDownloadedModels_doNotCrashOnError() throws Exception { 241 when(downloadedModelManager.listModels(MODEL_TYPE)).thenThrow(new IllegalStateException()); 242 243 assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).isEmpty(); 244 } 245 verifyWorkScheduledLogging(ReasonToSchedule reasonToSchedule)246 private void verifyWorkScheduledLogging(ReasonToSchedule reasonToSchedule) throws Exception { 247 TextClassifierDownloadWorkScheduled atom = 248 Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()); 249 verifyWorkScheduledAtom(atom, reasonToSchedule); 250 } 251 verifyWorkScheduledAtom( TextClassifierDownloadWorkScheduled atom, ReasonToSchedule reasonToSchedule)252 private void verifyWorkScheduledAtom( 253 TextClassifierDownloadWorkScheduled atom, ReasonToSchedule reasonToSchedule) { 254 assertThat(atom.getReasonToSchedule()).isEqualTo(reasonToSchedule); 255 assertThat(atom.getFailedToSchedule()).isFalse(); 256 } 257 } 258