/* * Copyright (C) 2018 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.android.textclassifier.downloader; import static com.google.common.truth.Truth.assertThat; import static org.mockito.Mockito.when; import android.content.Context; import android.os.LocaleList; import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; import androidx.work.WorkInfo; import androidx.work.WorkManager; import androidx.work.testing.WorkManagerTestInitHelper; import com.android.os.AtomsProto.TextClassifierDownloadWorkScheduled; import com.android.os.AtomsProto.TextClassifierDownloadWorkScheduled.ReasonToSchedule; import com.android.textclassifier.common.ModelType; import com.android.textclassifier.common.TextClassifierSettings; import com.android.textclassifier.common.statsd.TextClassifierDownloadLoggerTestRule; import com.android.textclassifier.testing.SetDefaultLocalesRule; import com.android.textclassifier.testing.TestingDeviceConfig; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import com.google.common.util.concurrent.MoreExecutors; import java.io.File; import java.util.List; import java.util.Locale; import java.util.stream.Collectors; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @RunWith(AndroidJUnit4.class) public final class ModelDownloadManagerTest { private static final String MODEL_PATH = "/data/test.model"; @ModelType.ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR; private static final String LOCALE_TAG = "en"; private static final LocaleList DEFAULT_LOCALE_LIST = new LocaleList(new Locale(LOCALE_TAG)); @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule(); @Rule public final TextClassifierDownloadLoggerTestRule loggerTestRule = new TextClassifierDownloadLoggerTestRule(); @Rule public final MockitoRule mocks = MockitoJUnit.rule(); private TestingDeviceConfig deviceConfig; private WorkManager workManager; private ModelDownloadManager downloadManager; private ModelDownloadManager downloadManagerWithBadWorkManager; @Mock DownloadedModelManager downloadedModelManager; @Before public void setUp() { Context context = ApplicationProvider.getApplicationContext(); WorkManagerTestInitHelper.initializeTestWorkManager(context); this.deviceConfig = new TestingDeviceConfig(); this.workManager = WorkManager.getInstance(context); this.downloadManager = new ModelDownloadManager( context, ModelDownloadWorker.class, () -> workManager, downloadedModelManager, new TextClassifierSettings(deviceConfig), MoreExecutors.newDirectExecutorService()); this.downloadManagerWithBadWorkManager = new ModelDownloadManager( context, ModelDownloadWorker.class, () -> { throw new IllegalStateException("WorkManager may fail!"); }, downloadedModelManager, new TextClassifierSettings(deviceConfig), MoreExecutors.newDirectExecutorService()); setDefaultLocalesRule.set(DEFAULT_LOCALE_LIST); deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true); } @After public void tearDown() { workManager.cancelUniqueWork(ModelDownloadManager.UNIQUE_QUEUE_NAME); DownloaderTestUtils.deleteRecursively( ApplicationProvider.getApplicationContext().getFilesDir()); } @Test public void onTextClassifierServiceCreated_workManagerCrashed() throws Exception { assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty(); downloadManagerWithBadWorkManager.onTextClassifierServiceCreated(); // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test TextClassifierDownloadWorkScheduled atom = Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()); assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.TCS_STARTED); assertThat(atom.getFailedToSchedule()).isTrue(); } @Test public void onTextClassifierServiceCreated_requestEnqueued() throws Exception { assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty(); downloadManager.onTextClassifierServiceCreated(); WorkInfo workInfo = Iterables.getOnlyElement( DownloaderTestUtils.queryWorkInfos( workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME)); assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED); // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED); } @Test public void onTextClassifierServiceCreated_localeListOverridden() throws Exception { assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty(); deviceConfig.setConfig(TextClassifierSettings.TESTING_LOCALE_LIST_OVERRIDE, "zh,fr"); downloadManager.onTextClassifierServiceCreated(); assertThat(Locale.getDefault()).isEqualTo(Locale.forLanguageTag("zh")); assertThat(LocaleList.getDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr")); assertThat(LocaleList.getAdjustedDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr")); // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED); } @Test public void onLocaleChanged_workManagerCrashed() throws Exception { downloadManagerWithBadWorkManager.onLocaleChanged(); TextClassifierDownloadWorkScheduled atom = Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()); assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.LOCALE_SETTINGS_CHANGED); assertThat(atom.getFailedToSchedule()).isTrue(); } @Test public void onLocaleChanged_requestEnqueued() throws Exception { downloadManager.onLocaleChanged(); WorkInfo workInfo = Iterables.getOnlyElement( DownloaderTestUtils.queryWorkInfos( workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME)); assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED); verifyWorkScheduledLogging(ReasonToSchedule.LOCALE_SETTINGS_CHANGED); } @Test public void onTextClassifierDeviceConfigChanged_workManagerCrashed() throws Exception { downloadManagerWithBadWorkManager.onTextClassifierDeviceConfigChanged(); TextClassifierDownloadWorkScheduled atom = Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()); assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.DEVICE_CONFIG_UPDATED); assertThat(atom.getFailedToSchedule()).isTrue(); } @Test public void onTextClassifierDeviceConfigChanged_requestEnqueued() throws Exception { downloadManager.onTextClassifierDeviceConfigChanged(); WorkInfo workInfo = Iterables.getOnlyElement( DownloaderTestUtils.queryWorkInfos( workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME)); assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED); verifyWorkScheduledLogging(ReasonToSchedule.DEVICE_CONFIG_UPDATED); } @Test public void onTextClassifierDeviceConfigChanged_downloaderDisabled() throws Exception { deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, false); downloadManager.onTextClassifierDeviceConfigChanged(); assertThat( DownloaderTestUtils.queryWorkInfos(workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME)) .isEmpty(); assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty(); } @Test public void onTextClassifierDeviceConfigChanged_newWorkDoNotReplaceOldWork() throws Exception { downloadManager.onTextClassifierDeviceConfigChanged(); downloadManager.onTextClassifierDeviceConfigChanged(); List workInfos = DownloaderTestUtils.queryWorkInfos(workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME); assertThat(workInfos.stream().map(WorkInfo::getState).collect(Collectors.toList())) .containsExactly(WorkInfo.State.ENQUEUED, WorkInfo.State.BLOCKED); List atoms = loggerTestRule.getLoggedDownloadWorkScheduledAtoms(); assertThat(atoms).hasSize(2); verifyWorkScheduledAtom(atoms.get(0), ReasonToSchedule.DEVICE_CONFIG_UPDATED); verifyWorkScheduledAtom(atoms.get(1), ReasonToSchedule.DEVICE_CONFIG_UPDATED); } @Test public void onTextClassifierDeviceConfigChanged_localeListOverridden() throws Exception { deviceConfig.setConfig(TextClassifierSettings.TESTING_LOCALE_LIST_OVERRIDE, "zh,fr"); downloadManager.onTextClassifierDeviceConfigChanged(); assertThat(Locale.getDefault()).isEqualTo(Locale.forLanguageTag("zh")); assertThat(LocaleList.getDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr")); assertThat(LocaleList.getAdjustedDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr")); verifyWorkScheduledLogging(ReasonToSchedule.DEVICE_CONFIG_UPDATED); } @Test public void listDownloadedModels() throws Exception { File modelFile = new File(MODEL_PATH); when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(ImmutableList.of(modelFile)); assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).containsExactly(modelFile); } @Test public void listDownloadedModels_doNotCrashOnError() throws Exception { when(downloadedModelManager.listModels(MODEL_TYPE)).thenThrow(new IllegalStateException()); assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).isEmpty(); } private void verifyWorkScheduledLogging(ReasonToSchedule reasonToSchedule) throws Exception { TextClassifierDownloadWorkScheduled atom = Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()); verifyWorkScheduledAtom(atom, reasonToSchedule); } private void verifyWorkScheduledAtom( TextClassifierDownloadWorkScheduled atom, ReasonToSchedule reasonToSchedule) { assertThat(atom.getReasonToSchedule()).isEqualTo(reasonToSchedule); assertThat(atom.getFailedToSchedule()).isFalse(); } }