• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.textclassifier.downloader;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static org.mockito.Mockito.when;
21 
22 import android.content.Context;
23 import android.os.LocaleList;
24 import androidx.test.core.app.ApplicationProvider;
25 import androidx.test.ext.junit.runners.AndroidJUnit4;
26 import androidx.work.WorkInfo;
27 import androidx.work.WorkManager;
28 import androidx.work.testing.WorkManagerTestInitHelper;
29 import com.android.os.AtomsProto.TextClassifierDownloadWorkScheduled;
30 import com.android.os.AtomsProto.TextClassifierDownloadWorkScheduled.ReasonToSchedule;
31 import com.android.textclassifier.common.ModelType;
32 import com.android.textclassifier.common.TextClassifierSettings;
33 import com.android.textclassifier.common.statsd.TextClassifierDownloadLoggerTestRule;
34 import com.android.textclassifier.testing.SetDefaultLocalesRule;
35 import com.android.textclassifier.testing.TestingDeviceConfig;
36 import com.google.common.collect.ImmutableList;
37 import com.google.common.collect.Iterables;
38 import com.google.common.util.concurrent.MoreExecutors;
39 import java.io.File;
40 import java.util.List;
41 import java.util.Locale;
42 import java.util.stream.Collectors;
43 import org.junit.After;
44 import org.junit.Before;
45 import org.junit.Rule;
46 import org.junit.Test;
47 import org.junit.runner.RunWith;
48 import org.mockito.Mock;
49 import org.mockito.junit.MockitoJUnit;
50 import org.mockito.junit.MockitoRule;
51 
52 @RunWith(AndroidJUnit4.class)
53 public final class ModelDownloadManagerTest {
54   private static final String MODEL_PATH = "/data/test.model";
55   @ModelType.ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
56   private static final String LOCALE_TAG = "en";
57   private static final LocaleList DEFAULT_LOCALE_LIST = new LocaleList(new Locale(LOCALE_TAG));
58 
59   @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
60 
61   @Rule
62   public final TextClassifierDownloadLoggerTestRule loggerTestRule =
63       new TextClassifierDownloadLoggerTestRule();
64 
65   @Rule public final MockitoRule mocks = MockitoJUnit.rule();
66 
67   private TestingDeviceConfig deviceConfig;
68   private WorkManager workManager;
69   private ModelDownloadManager downloadManager;
70   private ModelDownloadManager downloadManagerWithBadWorkManager;
71   @Mock DownloadedModelManager downloadedModelManager;
72 
73   @Before
setUp()74   public void setUp() {
75     Context context = ApplicationProvider.getApplicationContext();
76     WorkManagerTestInitHelper.initializeTestWorkManager(context);
77 
78     this.deviceConfig = new TestingDeviceConfig();
79     this.workManager = WorkManager.getInstance(context);
80     this.downloadManager =
81         new ModelDownloadManager(
82             context,
83             ModelDownloadWorker.class,
84             () -> workManager,
85             downloadedModelManager,
86             new TextClassifierSettings(deviceConfig),
87             MoreExecutors.newDirectExecutorService());
88     this.downloadManagerWithBadWorkManager =
89         new ModelDownloadManager(
90             context,
91             ModelDownloadWorker.class,
92             () -> {
93               throw new IllegalStateException("WorkManager may fail!");
94             },
95             downloadedModelManager,
96             new TextClassifierSettings(deviceConfig),
97             MoreExecutors.newDirectExecutorService());
98 
99     setDefaultLocalesRule.set(DEFAULT_LOCALE_LIST);
100     deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
101   }
102 
103   @After
tearDown()104   public void tearDown() {
105     workManager.cancelUniqueWork(ModelDownloadManager.UNIQUE_QUEUE_NAME);
106     DownloaderTestUtils.deleteRecursively(
107         ApplicationProvider.getApplicationContext().getFilesDir());
108   }
109 
110   @Test
onTextClassifierServiceCreated_workManagerCrashed()111   public void onTextClassifierServiceCreated_workManagerCrashed() throws Exception {
112     assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty();
113     downloadManagerWithBadWorkManager.onTextClassifierServiceCreated();
114 
115     // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test
116     TextClassifierDownloadWorkScheduled atom =
117         Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
118     assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.TCS_STARTED);
119     assertThat(atom.getFailedToSchedule()).isTrue();
120   }
121 
122   @Test
onTextClassifierServiceCreated_requestEnqueued()123   public void onTextClassifierServiceCreated_requestEnqueued() throws Exception {
124     assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty();
125     downloadManager.onTextClassifierServiceCreated();
126 
127     WorkInfo workInfo =
128         Iterables.getOnlyElement(
129             DownloaderTestUtils.queryWorkInfos(
130                 workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME));
131     assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
132     // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test
133     verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED);
134   }
135 
136   @Test
onTextClassifierServiceCreated_localeListOverridden()137   public void onTextClassifierServiceCreated_localeListOverridden() throws Exception {
138     assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty();
139     deviceConfig.setConfig(TextClassifierSettings.TESTING_LOCALE_LIST_OVERRIDE, "zh,fr");
140     downloadManager.onTextClassifierServiceCreated();
141 
142     assertThat(Locale.getDefault()).isEqualTo(Locale.forLanguageTag("zh"));
143     assertThat(LocaleList.getDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr"));
144     assertThat(LocaleList.getAdjustedDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr"));
145     // Assertion below is flaky: DeviceConfig listener may be trigerred by OS during test
146     verifyWorkScheduledLogging(ReasonToSchedule.TCS_STARTED);
147   }
148 
149   @Test
onLocaleChanged_workManagerCrashed()150   public void onLocaleChanged_workManagerCrashed() throws Exception {
151     downloadManagerWithBadWorkManager.onLocaleChanged();
152 
153     TextClassifierDownloadWorkScheduled atom =
154         Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
155     assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.LOCALE_SETTINGS_CHANGED);
156     assertThat(atom.getFailedToSchedule()).isTrue();
157   }
158 
159   @Test
onLocaleChanged_requestEnqueued()160   public void onLocaleChanged_requestEnqueued() throws Exception {
161     downloadManager.onLocaleChanged();
162 
163     WorkInfo workInfo =
164         Iterables.getOnlyElement(
165             DownloaderTestUtils.queryWorkInfos(
166                 workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME));
167     assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
168     verifyWorkScheduledLogging(ReasonToSchedule.LOCALE_SETTINGS_CHANGED);
169   }
170 
171   @Test
onTextClassifierDeviceConfigChanged_workManagerCrashed()172   public void onTextClassifierDeviceConfigChanged_workManagerCrashed() throws Exception {
173     downloadManagerWithBadWorkManager.onTextClassifierDeviceConfigChanged();
174 
175     TextClassifierDownloadWorkScheduled atom =
176         Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
177     assertThat(atom.getReasonToSchedule()).isEqualTo(ReasonToSchedule.DEVICE_CONFIG_UPDATED);
178     assertThat(atom.getFailedToSchedule()).isTrue();
179   }
180 
181   @Test
onTextClassifierDeviceConfigChanged_requestEnqueued()182   public void onTextClassifierDeviceConfigChanged_requestEnqueued() throws Exception {
183     downloadManager.onTextClassifierDeviceConfigChanged();
184 
185     WorkInfo workInfo =
186         Iterables.getOnlyElement(
187             DownloaderTestUtils.queryWorkInfos(
188                 workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME));
189     assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
190     verifyWorkScheduledLogging(ReasonToSchedule.DEVICE_CONFIG_UPDATED);
191   }
192 
193   @Test
onTextClassifierDeviceConfigChanged_downloaderDisabled()194   public void onTextClassifierDeviceConfigChanged_downloaderDisabled() throws Exception {
195     deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, false);
196     downloadManager.onTextClassifierDeviceConfigChanged();
197 
198     assertThat(
199             DownloaderTestUtils.queryWorkInfos(workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME))
200         .isEmpty();
201     assertThat(loggerTestRule.getLoggedDownloadWorkScheduledAtoms()).isEmpty();
202   }
203 
204   @Test
onTextClassifierDeviceConfigChanged_newWorkDoNotReplaceOldWork()205   public void onTextClassifierDeviceConfigChanged_newWorkDoNotReplaceOldWork() throws Exception {
206     downloadManager.onTextClassifierDeviceConfigChanged();
207     downloadManager.onTextClassifierDeviceConfigChanged();
208     List<WorkInfo> workInfos =
209         DownloaderTestUtils.queryWorkInfos(workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME);
210 
211     assertThat(workInfos.stream().map(WorkInfo::getState).collect(Collectors.toList()))
212         .containsExactly(WorkInfo.State.ENQUEUED, WorkInfo.State.BLOCKED);
213     List<TextClassifierDownloadWorkScheduled> atoms =
214         loggerTestRule.getLoggedDownloadWorkScheduledAtoms();
215     assertThat(atoms).hasSize(2);
216     verifyWorkScheduledAtom(atoms.get(0), ReasonToSchedule.DEVICE_CONFIG_UPDATED);
217     verifyWorkScheduledAtom(atoms.get(1), ReasonToSchedule.DEVICE_CONFIG_UPDATED);
218   }
219 
220   @Test
onTextClassifierDeviceConfigChanged_localeListOverridden()221   public void onTextClassifierDeviceConfigChanged_localeListOverridden() throws Exception {
222     deviceConfig.setConfig(TextClassifierSettings.TESTING_LOCALE_LIST_OVERRIDE, "zh,fr");
223     downloadManager.onTextClassifierDeviceConfigChanged();
224 
225     assertThat(Locale.getDefault()).isEqualTo(Locale.forLanguageTag("zh"));
226     assertThat(LocaleList.getDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr"));
227     assertThat(LocaleList.getAdjustedDefault()).isEqualTo(LocaleList.forLanguageTags("zh,fr"));
228     verifyWorkScheduledLogging(ReasonToSchedule.DEVICE_CONFIG_UPDATED);
229   }
230 
231   @Test
listDownloadedModels()232   public void listDownloadedModels() throws Exception {
233     File modelFile = new File(MODEL_PATH);
234     when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(ImmutableList.of(modelFile));
235 
236     assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).containsExactly(modelFile);
237   }
238 
239   @Test
listDownloadedModels_doNotCrashOnError()240   public void listDownloadedModels_doNotCrashOnError() throws Exception {
241     when(downloadedModelManager.listModels(MODEL_TYPE)).thenThrow(new IllegalStateException());
242 
243     assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).isEmpty();
244   }
245 
verifyWorkScheduledLogging(ReasonToSchedule reasonToSchedule)246   private void verifyWorkScheduledLogging(ReasonToSchedule reasonToSchedule) throws Exception {
247     TextClassifierDownloadWorkScheduled atom =
248         Iterables.getOnlyElement(loggerTestRule.getLoggedDownloadWorkScheduledAtoms());
249     verifyWorkScheduledAtom(atom, reasonToSchedule);
250   }
251 
verifyWorkScheduledAtom( TextClassifierDownloadWorkScheduled atom, ReasonToSchedule reasonToSchedule)252   private void verifyWorkScheduledAtom(
253       TextClassifierDownloadWorkScheduled atom, ReasonToSchedule reasonToSchedule) {
254     assertThat(atom.getReasonToSchedule()).isEqualTo(reasonToSchedule);
255     assertThat(atom.getFailedToSchedule()).isFalse();
256   }
257 }
258