• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.federatedcompute.services.scheduling;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static com.google.common.truth.Truth.assertWithMessage;
21 
22 import static org.mockito.Mockito.when;
23 
24 import static java.lang.Math.min;
25 
26 import android.app.job.JobInfo;
27 import android.app.job.JobScheduler;
28 import android.content.ComponentName;
29 import android.content.Context;
30 import android.federatedcompute.aidl.IFederatedComputeCallback;
31 import android.federatedcompute.common.TrainingInterval;
32 import android.federatedcompute.common.TrainingOptions;
33 
34 import androidx.test.core.app.ApplicationProvider;
35 
36 import com.android.federatedcompute.services.common.Clock;
37 import com.android.federatedcompute.services.common.Flags;
38 import com.android.federatedcompute.services.common.TaskRetry;
39 import com.android.federatedcompute.services.common.TrainingResult;
40 import com.android.federatedcompute.services.data.FederatedTrainingTask;
41 import com.android.federatedcompute.services.data.FederatedTrainingTaskDao;
42 import com.android.federatedcompute.services.data.FederatedTrainingTaskDbHelper;
43 import com.android.federatedcompute.services.data.fbs.SchedulingMode;
44 import com.android.federatedcompute.services.data.fbs.SchedulingReason;
45 import com.android.federatedcompute.services.data.fbs.TrainingConstraints;
46 import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions;
47 
48 import com.google.flatbuffers.FlatBufferBuilder;
49 
50 import org.junit.After;
51 import org.junit.Before;
52 import org.junit.Test;
53 import org.junit.runner.RunWith;
54 import org.mockito.Mock;
55 import org.mockito.junit.MockitoJUnitRunner;
56 
57 import java.nio.ByteBuffer;
58 import java.util.List;
59 import java.util.concurrent.CountDownLatch;
60 
61 import javax.annotation.Nullable;
62 
63 @RunWith(MockitoJUnitRunner.class)
64 public final class FederatedComputeJobManagerTest {
65     private static final String POPULATION_NAME1 = "population1";
66     private static final String POPULATION_NAME2 = "population2";
67     private static final int JOB_ID1 = 700000001;
68     private static final int JOB_ID2 = 700000002;
69     private static final long DEFAULT_SCHEDULING_PERIOD_SECS = 1234;
70     private static final long DEFAULT_SCHEDULING_PERIOD_MILLIS =
71             DEFAULT_SCHEDULING_PERIOD_SECS * 1000;
72     private static final long MAX_SCHEDULING_PERIOD_SECS = 912000;
73     private static final long MAX_SCHEDULING_INTERVAL_SECS_FOR_FEDERATED_COMPUTATION = 604800L;
74 
75     private static final String TRAINING_JOB_SERVICE =
76             "com.android.federatedcompute.services.training.FederatedJobService";
77     private static final long CURRENT_TIME_MILLIS = 1000L;
78     private static final byte[] DEFAULT_CONSTRAINTS = createDefaultTrainingConstraints();
79     private static final TrainingOptions OPTIONS1 =
80             new TrainingOptions.Builder()
81                     .setPopulationName(POPULATION_NAME1)
82                     .setJobSchedulerJobId(JOB_ID1)
83                     .build();
84     private static final TrainingOptions OPTIONS2 =
85             new TrainingOptions.Builder()
86                     .setPopulationName(POPULATION_NAME2)
87                     .setJobSchedulerJobId(JOB_ID2)
88                     .build();
89     private static final TaskRetry TASK_RETRY =
90             new TaskRetry.Builder().setMinDelay(5000000).setMaxDelay(6000000).build();
91 
92     private FederatedComputeJobManager mJobManager;
93     private Context mContext;
94     private FederatedTrainingTaskDao mTrainingTaskDao;
95     private boolean mSuccess = false;
96     private final CountDownLatch mLatch = new CountDownLatch(1);
97 
98     @Mock private Clock mClock;
99     @Mock private Flags mMockFlags;
100     private JobScheduler mJobScheduler;
101 
102     @Before
setUp()103     public void setUp() {
104         mContext = ApplicationProvider.getApplicationContext();
105         mJobScheduler = mContext.getSystemService(JobScheduler.class);
106         mJobScheduler.cancelAll();
107         mTrainingTaskDao = FederatedTrainingTaskDao.getInstanceForTest(mContext);
108         mJobManager =
109                 new FederatedComputeJobManager(
110                         mContext,
111                         mTrainingTaskDao,
112                         new JobSchedulerHelper(mClock),
113                         mClock,
114                         mMockFlags);
115         when(mClock.currentTimeMillis()).thenReturn(CURRENT_TIME_MILLIS);
116         when(mMockFlags.getDefaultSchedulingPeriodSecs())
117                 .thenReturn(DEFAULT_SCHEDULING_PERIOD_SECS);
118         when(mMockFlags.getMaxSchedulingIntervalSecsForFederatedComputation())
119                 .thenReturn(MAX_SCHEDULING_INTERVAL_SECS_FOR_FEDERATED_COMPUTATION);
120         when(mMockFlags.getMinSchedulingIntervalSecsForFederatedComputation()).thenReturn(1L);
121         when(mMockFlags.getMaxSchedulingPeriodSecs()).thenReturn(MAX_SCHEDULING_PERIOD_SECS);
122     }
123 
124     @After
tearDown()125     public void tearDown() {
126         // Manually clean up the database.
127         mTrainingTaskDao.clearDatabase();
128         FederatedTrainingTaskDbHelper dbHelper =
129                 FederatedTrainingTaskDbHelper.getInstanceForTest(mContext);
130         dbHelper.getWritableDatabase().close();
131         dbHelper.getReadableDatabase().close();
132         dbHelper.close();
133     }
134 
135     @Test
testOnTrainerStartCalledSuccess()136     public void testOnTrainerStartCalledSuccess() throws Exception {
137         when(mClock.currentTimeMillis()).thenReturn(1000L).thenReturn(2000L);
138 
139         mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback());
140 
141         assertThat(mSuccess).isTrue();
142         List<FederatedTrainingTask> taskList =
143                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
144         assertThat(taskList)
145                 .containsExactly(
146                         basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, null)
147                                 .creationTime(1000L)
148                                 .lastScheduledTime(1000L)
149                                 .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK)
150                                 .earliestNextRunTime(1000 + DEFAULT_SCHEDULING_PERIOD_MILLIS)
151                                 .build());
152     }
153 
154     @Test
testOnTrainerStartCalled_firstTime()155     public void testOnTrainerStartCalled_firstTime() throws Exception {
156         when(mClock.currentTimeMillis()).thenReturn(1000L);
157         // Make three onTrainerStart calls, each with different job ID and session name.
158         mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback());
159         when(mClock.currentTimeMillis()).thenReturn(2000L);
160         mJobManager.onTrainerStartCalled(OPTIONS2, new TestFederatedComputeCallback());
161         mLatch.await();
162 
163         assertThat(mSuccess).isTrue();
164         // verify training tasks in database.
165         List<FederatedTrainingTask> taskList =
166                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
167         assertThat(taskList)
168                 .containsExactly(
169                         basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, null)
170                                 .creationTime(1000L)
171                                 .lastScheduledTime(1000L)
172                                 .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK)
173                                 .earliestNextRunTime(1000 + DEFAULT_SCHEDULING_PERIOD_MILLIS)
174                                 .build(),
175                         basicFLTrainingTaskBuilder(JOB_ID2, POPULATION_NAME2, null)
176                                 .creationTime(2000L)
177                                 .lastScheduledTime(2000L)
178                                 .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK)
179                                 .earliestNextRunTime(2000 + DEFAULT_SCHEDULING_PERIOD_MILLIS)
180                                 .build());
181 
182         assertThat(mJobScheduler.getAllPendingJobs()).hasSize(2);
183         assertJobInfosMatch(
184                 mJobScheduler.getPendingJob(JOB_ID1),
185                 buildExpectedJobInfo(JOB_ID1, DEFAULT_SCHEDULING_PERIOD_MILLIS));
186         assertJobInfosMatch(
187                 mJobScheduler.getPendingJob(JOB_ID2),
188                 buildExpectedJobInfo(JOB_ID2, DEFAULT_SCHEDULING_PERIOD_MILLIS));
189     }
190 
191     @Test
testOnTrainerStartCalledFL_withIntervalSmallerThanDefaultInterval()192     public void testOnTrainerStartCalledFL_withIntervalSmallerThanDefaultInterval()
193             throws Exception {
194         testOnTrainerStartCalledFLWithInterval(
195                 /* userDefinedIntervalMillis= */ 1000000, /* defaultIntervalMillis= */ 2000000);
196     }
197 
198     @Test
testOnTrainerStartCalledFL_withIntervalLargerThanDefaultInterval()199     public void testOnTrainerStartCalledFL_withIntervalLargerThanDefaultInterval()
200             throws Exception {
201         testOnTrainerStartCalledFLWithInterval(
202                 /* userDefinedIntervalMillis= */ 2000000, /* defaultIntervalMillis= */ 1000000);
203     }
204 
testOnTrainerStartCalledFLWithInterval( long userDefinedIntervalMillis, long defaultIntervalMillis)205     private void testOnTrainerStartCalledFLWithInterval(
206             long userDefinedIntervalMillis, long defaultIntervalMillis) throws Exception {
207         when(mMockFlags.getDefaultSchedulingPeriodSecs()).thenReturn(defaultIntervalMillis / 1000);
208         TrainingOptions trainerOptions =
209                 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1)
210                         .setTrainingInterval(
211                                 new TrainingInterval.Builder()
212                                         .setSchedulingMode(
213                                                 TrainingInterval.SCHEDULING_MODE_RECURRENT)
214                                         .setMinimumIntervalMillis(userDefinedIntervalMillis)
215                                         .build())
216                         .build();
217         mJobManager.onTrainerStartCalled(trainerOptions, new TestFederatedComputeCallback());
218 
219         byte[] trainingIntervalOptions =
220                 createTrainingIntervalOptions(SchedulingMode.RECURRENT, userDefinedIntervalMillis);
221 
222         long expectedInterval = min(userDefinedIntervalMillis, defaultIntervalMillis);
223         FederatedTrainingTask expectedTask =
224                 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, trainingIntervalOptions)
225                         .earliestNextRunTime(CURRENT_TIME_MILLIS + expectedInterval)
226                         .lastScheduledTime(CURRENT_TIME_MILLIS)
227                         .creationTime(CURRENT_TIME_MILLIS)
228                         .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK)
229                         .build();
230         List<FederatedTrainingTask> taskList =
231                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
232 
233         assertThat(taskList).containsExactly(expectedTask);
234         assertJobInfosMatch(
235                 mJobScheduler.getPendingJob(JOB_ID1),
236                 buildExpectedJobInfo(JOB_ID1, expectedInterval));
237     }
238 
239     /**
240      * Tests onTrainerStart being called multiple times with the same parameters (the common
241      * expected use case).
242      *
243      * <p>After the first call, most fields in the task (like creation time, earliest next run time,
244      * etc.) must be preserved, and only certain fields (like last scheduled time) should be
245      * updated.
246      */
247     @Test
testOnTrainerStartCalled_multipleTimes_sameParams()248     public void testOnTrainerStartCalled_multipleTimes_sameParams() throws Exception {
249         when(mClock.currentTimeMillis()).thenReturn(1000L);
250         mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback());
251 
252         when(mClock.currentTimeMillis()).thenReturn(2000L);
253         mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback());
254 
255         when(mClock.currentTimeMillis()).thenReturn(3000L);
256         mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback());
257 
258         List<FederatedTrainingTask> taskList =
259                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
260         FederatedTrainingTask expectedTask =
261                 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, null)
262                         .earliestNextRunTime(1000 + DEFAULT_SCHEDULING_PERIOD_MILLIS)
263                         .lastScheduledTime(3000L)
264                         .creationTime(1000L)
265                         .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK)
266                         .build();
267         assertThat(taskList).containsExactly(expectedTask);
268 
269         assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1);
270         assertJobInfosMatch(
271                 mJobScheduler.getPendingJob(JOB_ID1),
272                 buildExpectedJobInfo(JOB_ID1, DEFAULT_SCHEDULING_PERIOD_MILLIS));
273     }
274 
275     /**
276      * Tests when the user specified interval is larger than the maximum server specified interval,
277      * multiple scheduling with same user specified interval will not be incorrectly capped at the
278      * maximum server specified interval.
279      */
280     @Test
testOnTrainerStartCalled_multipleTimes_sameParamsFLWithIntervalLargerThanServerMax()281     public void testOnTrainerStartCalled_multipleTimes_sameParamsFLWithIntervalLargerThanServerMax()
282             throws Exception {
283         long minIntervalMills = 10000L; // 10 seconds
284         // Maximum server specified interval is 5 seconds
285         when(mMockFlags.getMaxSchedulingPeriodSecs()).thenReturn(5L);
286         TrainingOptions trainingOptions =
287                 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1)
288                         .setTrainingInterval(
289                                 new TrainingInterval.Builder()
290                                         .setSchedulingMode(
291                                                 TrainingInterval.SCHEDULING_MODE_RECURRENT)
292                                         .setMinimumIntervalMillis(minIntervalMills)
293                                         .build())
294                         .build();
295 
296         when(mClock.currentTimeMillis()).thenReturn(1000L);
297         mJobManager.onTrainerStartCalled(trainingOptions, new TestFederatedComputeCallback());
298 
299         when(mClock.currentTimeMillis()).thenReturn(2000L);
300         mJobManager.onTrainerStartCalled(trainingOptions, new TestFederatedComputeCallback());
301 
302         when(mClock.currentTimeMillis()).thenReturn(3000L);
303         mJobManager.onTrainerStartCalled(trainingOptions, new TestFederatedComputeCallback());
304 
305         List<FederatedTrainingTask> taskList =
306                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
307         byte[] expectedInterval =
308                 createTrainingIntervalOptions(SchedulingMode.RECURRENT, minIntervalMills);
309         FederatedTrainingTask expectedTask =
310                 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, expectedInterval)
311                         .earliestNextRunTime(1000 + minIntervalMills)
312                         .lastScheduledTime(3000L)
313                         .creationTime(1000L)
314                         .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK)
315                         .build();
316         assertThat(taskList).containsExactly(expectedTask);
317 
318         assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1);
319         assertJobInfosMatch(
320                 mJobScheduler.getPendingJob(JOB_ID1),
321                 buildExpectedJobInfo(JOB_ID1, minIntervalMills));
322     }
323 
324     /**
325      * Tests when a task got scheduled with the same set of parameters multiple times, the brella
326      * defined max for user specified interval has been lowered between the multiple scheduling
327      * events, the user specified interval should be always guarded with the latest max.
328      */
329     @Test
testOnTrainerStartCalled_multipleTimes_sameParamsFLWithIntervalDifferentMax()330     public void testOnTrainerStartCalled_multipleTimes_sameParamsFLWithIntervalDifferentMax()
331             throws Exception {
332         // Initial max 20 seconds is larger than the user specified interval.
333         when(mMockFlags.getMaxSchedulingIntervalSecsForFederatedComputation()).thenReturn(20L);
334         long minIntervalMills = 10000L; // 10 seconds
335         TrainingOptions trainingOptions =
336                 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1)
337                         .setTrainingInterval(
338                                 new TrainingInterval.Builder()
339                                         .setSchedulingMode(
340                                                 TrainingInterval.SCHEDULING_MODE_RECURRENT)
341                                         .setMinimumIntervalMillis(minIntervalMills)
342                                         .build())
343                         .build();
344 
345         when(mClock.currentTimeMillis()).thenReturn(1000L);
346         mJobManager.onTrainerStartCalled(trainingOptions, new TestFederatedComputeCallback());
347 
348         List<FederatedTrainingTask> taskList =
349                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
350         byte[] expectedInterval =
351                 createTrainingIntervalOptions(SchedulingMode.RECURRENT, minIntervalMills);
352         FederatedTrainingTask expectedTask =
353                 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, expectedInterval)
354                         .earliestNextRunTime(1000L + minIntervalMills)
355                         .lastScheduledTime(1000L)
356                         .creationTime(1000L)
357                         .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK)
358                         .build();
359         assertThat(taskList).containsExactly(expectedTask);
360 
361         assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1);
362         assertJobInfosMatch(
363                 mJobScheduler.getPendingJob(JOB_ID1),
364                 buildExpectedJobInfo(JOB_ID1, minIntervalMills));
365 
366         // Now lower allowed max for the user specified interval
367         long newMaxSec = 5L;
368         long newMinIntervalMills = newMaxSec * 1000;
369         when(mMockFlags.getMaxSchedulingIntervalSecsForFederatedComputation())
370                 .thenReturn(newMaxSec);
371         TrainingOptions newTrainingOptions =
372                 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1)
373                         .setTrainingInterval(
374                                 new TrainingInterval.Builder()
375                                         .setSchedulingMode(
376                                                 TrainingInterval.SCHEDULING_MODE_RECURRENT)
377                                         .setMinimumIntervalMillis(newMinIntervalMills)
378                                         .build())
379                         .build();
380 
381         when(mClock.currentTimeMillis()).thenReturn(2000L);
382         mJobManager.onTrainerStartCalled(newTrainingOptions, new TestFederatedComputeCallback());
383 
384         taskList = mTrainingTaskDao.getFederatedTrainingTask(null, null);
385         expectedInterval =
386                 createTrainingIntervalOptions(SchedulingMode.RECURRENT, newMinIntervalMills);
387         expectedTask =
388                 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, expectedInterval)
389                         .earliestNextRunTime(2000L + newMinIntervalMills)
390                         .lastScheduledTime(2000L)
391                         .creationTime(1000L)
392                         .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK)
393                         .build();
394         assertThat(taskList).containsExactly(expectedTask);
395         assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1);
396         assertJobInfosMatch(
397                 mJobScheduler.getPendingJob(JOB_ID1),
398                 buildExpectedJobInfo(JOB_ID1, newMinIntervalMills));
399     }
400 
401     @Test
testOnTrainerStartCalled_fLCustomerSpecifiedIntervalSmallerThanDefinedMin()402     public void testOnTrainerStartCalled_fLCustomerSpecifiedIntervalSmallerThanDefinedMin()
403             throws Exception {
404         when(mMockFlags.getDefaultSchedulingPeriodSecs()).thenReturn(2000L);
405         long minTrainingIntervalSecByFederatedCompute = 1800L;
406         long minTrainingIntervalMillsByFederatedCompute =
407                 minTrainingIntervalSecByFederatedCompute * 1000;
408         when(mMockFlags.getMinSchedulingIntervalSecsForFederatedComputation())
409                 .thenReturn(minTrainingIntervalSecByFederatedCompute);
410 
411         TrainingOptions trainingOptions =
412                 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1)
413                         .setTrainingInterval(
414                                 new TrainingInterval.Builder()
415                                         .setSchedulingMode(
416                                                 TrainingInterval.SCHEDULING_MODE_RECURRENT)
417                                         .setMinimumIntervalMillis(1000L)
418                                         .build())
419                         .build();
420 
421         when(mClock.currentTimeMillis()).thenReturn(1000L);
422         mJobManager.onTrainerStartCalled(trainingOptions, new TestFederatedComputeCallback());
423 
424         List<FederatedTrainingTask> taskList =
425                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
426         byte[] expectedInterval = createTrainingIntervalOptions(SchedulingMode.RECURRENT, 1000L);
427         FederatedTrainingTask expectedTask =
428                 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, expectedInterval)
429                         .earliestNextRunTime(1000L + minTrainingIntervalMillsByFederatedCompute)
430                         .lastScheduledTime(1000L)
431                         .creationTime(1000L)
432                         .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK)
433                         .build();
434         assertThat(taskList).containsExactly(expectedTask);
435 
436         assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1);
437         assertJobInfosMatch(
438                 mJobScheduler.getPendingJob(JOB_ID1),
439                 buildExpectedJobInfo(JOB_ID1, minTrainingIntervalMillsByFederatedCompute));
440     }
441 
442     @Test
testOnTrainerStartCalled_trainingIntervalChange_FL()443     public void testOnTrainerStartCalled_trainingIntervalChange_FL() throws Exception {
444         when(mClock.currentTimeMillis()).thenReturn(1000L);
445         mJobManager.onTrainerStartCalled(
446                 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1).build(),
447                 new TestFederatedComputeCallback());
448 
449         long minTrainingIntervalMillis = 60000L;
450         when(mClock.currentTimeMillis()).thenReturn(2000L);
451         mJobManager.onTrainerStartCalled(
452                 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1)
453                         .setTrainingInterval(
454                                 new TrainingInterval.Builder()
455                                         .setSchedulingMode(
456                                                 TrainingInterval.SCHEDULING_MODE_RECURRENT)
457                                         .setMinimumIntervalMillis(minTrainingIntervalMillis)
458                                         .build())
459                         .build(),
460                 new TestFederatedComputeCallback());
461         byte[] trainingInterval =
462                 createTrainingIntervalOptions(SchedulingMode.RECURRENT, minTrainingIntervalMillis);
463         verifyTaskAndJobAfterIntervalChange(
464                 trainingInterval, 1000, 2000, minTrainingIntervalMillis);
465 
466         long newInterval = 70000L;
467         when(mClock.currentTimeMillis()).thenReturn(3000L);
468         mJobManager.onTrainerStartCalled(
469                 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1)
470                         .setTrainingInterval(
471                                 new TrainingInterval.Builder()
472                                         .setSchedulingMode(
473                                                 TrainingInterval.SCHEDULING_MODE_RECURRENT)
474                                         .setMinimumIntervalMillis(newInterval)
475                                         .build())
476                         .build(),
477                 new TestFederatedComputeCallback());
478         byte[] trainingIntervalOption2 =
479                 createTrainingIntervalOptions(SchedulingMode.RECURRENT, newInterval);
480         // Verify the creation time not changed, modified time is set to now, and the min interval
481         // is set to the new interval.
482         verifyTaskAndJobAfterIntervalChange(trainingIntervalOption2, 1000, 3000, newInterval);
483 
484         when(mClock.currentTimeMillis()).thenReturn(4000L);
485         mJobManager.onTrainerStartCalled(
486                 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1)
487                         .setTrainingInterval(
488                                 new TrainingInterval.Builder()
489                                         .setSchedulingMode(
490                                                 TrainingInterval.SCHEDULING_MODE_ONE_TIME)
491                                         .build())
492                         .build(),
493                 new TestFederatedComputeCallback());
494         byte[] trainingIntervalOption3 = createTrainingIntervalOptions(SchedulingMode.ONE_TIME, 0L);
495         // Verify the creation time not changed, modified time is set to now, and the min interval
496         // is set to the new interval.
497         verifyTaskAndJobAfterIntervalChange(
498                 trainingIntervalOption3, 1000, 4000, DEFAULT_SCHEDULING_PERIOD_MILLIS);
499 
500         // Transition back to not set
501         when(mClock.currentTimeMillis()).thenReturn(5000L);
502         mJobManager.onTrainerStartCalled(
503                 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1).build(),
504                 new TestFederatedComputeCallback());
505         verifyTaskAndJobAfterIntervalChange(null, 1000, 5000, DEFAULT_SCHEDULING_PERIOD_MILLIS);
506     }
507 
verifyTaskAndJobAfterIntervalChange( @ullable byte[] trainingIntervalOptions, long createTimeMillis, long modifyTimeMillis, long expectedIntervalMillis)508     private void verifyTaskAndJobAfterIntervalChange(
509             @Nullable byte[] trainingIntervalOptions,
510             long createTimeMillis,
511             long modifyTimeMillis,
512             long expectedIntervalMillis)
513             throws Exception {
514         List<FederatedTrainingTask> taskList =
515                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
516         FederatedTrainingTask expectedTask =
517                 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, trainingIntervalOptions)
518                         .earliestNextRunTime(modifyTimeMillis + expectedIntervalMillis)
519                         .lastScheduledTime(modifyTimeMillis)
520                         .creationTime(createTimeMillis)
521                         .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK)
522                         .build();
523         assertThat(taskList).containsExactly(expectedTask);
524 
525         assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1);
526         assertJobInfosMatch(
527                 mJobScheduler.getPendingJob(JOB_ID1),
528                 buildExpectedJobInfo(JOB_ID1, expectedIntervalMillis));
529     }
530 
531     @Test
testOnTrainerStartCalled_multipleTimes_changingPopulationName()532     public void testOnTrainerStartCalled_multipleTimes_changingPopulationName() throws Exception {
533         // Only change the population name.
534         int jobId = JOB_ID1;
535         doTestOnTrainerStartCalled_multipleTimes_changingParams(
536                 jobId,
537                 POPULATION_NAME1,
538                 jobId,
539                 POPULATION_NAME2,
540                 SchedulingReason.SCHEDULING_REASON_NEW_TASK);
541     }
542 
543     @Test
testOnTrainerStartCalled_twoJobsWithSamePopulationName()544     public void testOnTrainerStartCalled_twoJobsWithSamePopulationName() throws Exception {
545         // Change both the job ID and session name between Trainer.start calls.
546         doTestOnTrainerStartCalled_multipleTimes_changingParams(
547                 JOB_ID1,
548                 POPULATION_NAME1,
549                 JOB_ID2,
550                 POPULATION_NAME1,
551                 SchedulingReason.SCHEDULING_REASON_NEW_TASK);
552     }
553 
554     @Test
testOnTrainerStartCalled_multipleTimes_changingJobId()555     public void testOnTrainerStartCalled_multipleTimes_changingJobId() throws Exception {
556         // Only change the job ID.
557         String populationName = POPULATION_NAME1;
558         doTestOnTrainerStartCalled_multipleTimes_changingParams(
559                 JOB_ID1,
560                 populationName,
561                 JOB_ID2,
562                 populationName,
563                 SchedulingReason.SCHEDULING_REASON_NEW_TASK);
564     }
565 
doTestOnTrainerStartCalled_multipleTimes_changingParams( int jobId1, String populationName1, int jobId2, String populationName2, int expectedSchedulingReason)566     private void doTestOnTrainerStartCalled_multipleTimes_changingParams(
567             int jobId1,
568             String populationName1,
569             int jobId2,
570             String populationName2,
571             int expectedSchedulingReason)
572             throws Exception {
573         when(mClock.currentTimeMillis()).thenReturn(1000L);
574         TrainingOptions options1 =
575                 new TrainingOptions.Builder()
576                         .setPopulationName(populationName1)
577                         .setJobSchedulerJobId(jobId1)
578                         .build();
579         mJobManager.onTrainerStartCalled(options1, new TestFederatedComputeCallback());
580 
581         // Pass in a new population name.
582         when(mClock.currentTimeMillis()).thenReturn(2000L);
583         TrainingOptions options2 =
584                 new TrainingOptions.Builder()
585                         .setPopulationName(populationName2)
586                         .setJobSchedulerJobId(jobId2)
587                         .build();
588         mJobManager.onTrainerStartCalled(options2, new TestFederatedComputeCallback());
589 
590         long earliestNextRunTimeMillis = 2000 + DEFAULT_SCHEDULING_PERIOD_MILLIS;
591         long minLatencyMillis = DEFAULT_SCHEDULING_PERIOD_MILLIS;
592         // If none of the job id, session name, population name and InAppTrainingConstraints
593         // changes,
594         // the previous earliest next
595         // run time will not change.
596         if (jobId1 == jobId2 && populationName1.equals(populationName2)) {
597             earliestNextRunTimeMillis = 1000 + DEFAULT_SCHEDULING_PERIOD_MILLIS;
598         }
599         List<FederatedTrainingTask> taskList =
600                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
601         FederatedTrainingTask expectedTask =
602                 basicFLTrainingTaskBuilder(jobId2, populationName2, null)
603                         .earliestNextRunTime(earliestNextRunTimeMillis)
604                         .lastScheduledTime(2000L)
605                         .creationTime(populationName1.equals(populationName2) ? 1000L : 2000L)
606                         .constraints(DEFAULT_CONSTRAINTS)
607                         .schedulingReason(expectedSchedulingReason)
608                         .build();
609         assertThat(taskList).containsExactly(expectedTask);
610 
611         assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1);
612         assertJobInfosMatch(
613                 mJobScheduler.getPendingJob(jobId2),
614                 buildExpectedJobInfo(jobId2, minLatencyMillis));
615     }
616 
617     @Test
testOnTrainingStarted_doesNotExist()618     public void testOnTrainingStarted_doesNotExist() throws Exception {
619         when(mClock.currentTimeMillis()).thenReturn(1000L);
620         FederatedTrainingTask taskToRun = mJobManager.onTrainingStarted(JOB_ID1);
621 
622         // No task should be found.
623         assertThat(taskToRun).isNull();
624         List<FederatedTrainingTask> taskList =
625                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
626         assertThat(taskList).isEmpty();
627     }
628 
629     @Test
testOnTrainingStarted_taskTtling_noTtlSet()630     public void testOnTrainingStarted_taskTtling_noTtlSet() throws Exception {
631         // Set task TTL to 0, which should disable TTLing.
632         when(mMockFlags.getTrainingTimeForLiveSeconds()).thenReturn(0L);
633 
634         long nowMillis = 1000;
635         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
636         mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback());
637 
638         // Simulate attempting to run a task a lot later. This should not fail, b/c we're not yet
639         // past the TTL threshold.
640         assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNotNull();
641         assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).hasSize(1);
642     }
643 
644     @Test
testOnTrainingStarted_taskTtling()645     public void testOnTrainingStarted_taskTtling() throws Exception {
646         // Set task TTL to 1 second.
647         when(mMockFlags.getTrainingTimeForLiveSeconds()).thenReturn(1L);
648 
649         when(mClock.currentTimeMillis()).thenReturn(1000L);
650         mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback());
651 
652         // Simulate attempting to run a task one second later. This should not fail, b/c we're not
653         // yet
654         // past the TTL threshold.
655         long nowMillis = 2000;
656         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
657         assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNotNull();
658         assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).hasSize(1);
659 
660         // Now reschedule again, should keep the task alive for another second.
661         mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback());
662 
663         // The task should again still be alive a second later.
664         nowMillis = 3000;
665         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
666         assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNotNull();
667 
668         // Now move forward one millisecond. The task should now get TTLd.
669         nowMillis = 3001;
670         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
671         assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNull();
672         assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).isEmpty();
673     }
674 
675     @Test
testRescheduleFLTask_success()676     public void testRescheduleFLTask_success() throws Exception {
677         long nowMillis = 1000;
678         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
679         mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback());
680 
681         nowMillis = 2000;
682         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
683         mJobManager.onTrainingStarted(JOB_ID1);
684 
685         nowMillis = 3000;
686         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
687         mJobManager.onTrainingCompleted(
688                 JOB_ID1,
689                 POPULATION_NAME1,
690                 createTrainingIntervalOptionsAsRoot(SchedulingMode.RECURRENT, 0),
691                 TASK_RETRY,
692                 TrainingResult.SUCCESS);
693 
694         assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNotNull();
695         assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).hasSize(1);
696     }
697 
698     @Test
testRescheduleFLTask_oneoff_success()699     public void testRescheduleFLTask_oneoff_success() throws Exception {
700         long nowMillis = 1000;
701         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
702         mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback());
703 
704         nowMillis = 2000;
705         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
706         mJobManager.onTrainingStarted(JOB_ID1);
707 
708         nowMillis = 3000;
709         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
710         mJobManager.onTrainingCompleted(
711                 JOB_ID1,
712                 POPULATION_NAME1,
713                 createTrainingIntervalOptionsAsRoot(SchedulingMode.ONE_TIME, 0),
714                 TASK_RETRY,
715                 TrainingResult.SUCCESS);
716 
717         assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNull();
718         assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).isEmpty();
719     }
720 
721     @Test
testRescheduleFLTask_didnotContribute_oneOff()722     public void testRescheduleFLTask_didnotContribute_oneOff() throws Exception {
723         long serverRetryDelayMillis = 5000_000;
724 
725         long nowMillis = 1000;
726         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
727         TrainingOptions trainerOptions =
728                 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1)
729                         .setTrainingInterval(
730                                 new TrainingInterval.Builder()
731                                         .setSchedulingMode(
732                                                 TrainingInterval.SCHEDULING_MODE_ONE_TIME)
733                                         .build())
734                         .build();
735         mJobManager.onTrainerStartCalled(trainerOptions, new TestFederatedComputeCallback());
736 
737         nowMillis = 2000;
738         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
739         mJobManager.onTrainingStarted(JOB_ID1);
740 
741         nowMillis = 3000;
742         byte[] intervalOptions = createTrainingIntervalOptions(SchedulingMode.ONE_TIME, 0);
743         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
744         mJobManager.onTrainingCompleted(
745                 JOB_ID1,
746                 POPULATION_NAME1,
747                 TrainingIntervalOptions.getRootAsTrainingIntervalOptions(
748                         ByteBuffer.wrap(intervalOptions)),
749                 new TaskRetry.Builder()
750                         .setMinDelay(serverRetryDelayMillis)
751                         .setMaxDelay(serverRetryDelayMillis)
752                         .build(),
753                 TrainingResult.FAIL);
754 
755         List<FederatedTrainingTask> taskList =
756                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
757         FederatedTrainingTask expectedTask =
758                 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, intervalOptions)
759                         .creationTime(1000L)
760                         .lastScheduledTime(1000L)
761                         .lastRunStartTime(2000L)
762                         .lastRunEndTime(3000L)
763                         .schedulingReason(
764                                 SchedulingReason.SCHEDULING_REASON_FEDERATED_COMPUTATION_RETRY)
765                         .earliestNextRunTime(3000 + serverRetryDelayMillis)
766                         .build();
767         assertThat(taskList).containsExactly(expectedTask);
768 
769         assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1);
770         assertJobInfosMatch(
771                 mJobScheduler.getPendingJob(JOB_ID1),
772                 buildExpectedJobInfo(JOB_ID1, serverRetryDelayMillis));
773     }
774 
775     /** Reschedule a recurrent fl task with the user defined interval. */
776     @Test
testRescheduleFLTask_success_recurrent_userDefinedInterval()777     public void testRescheduleFLTask_success_recurrent_userDefinedInterval() throws Exception {
778         // The user defined interval is larger than the server specified interval.
779         long minRetryDelayMillis = 3000_000;
780         long maxRetryDelayMillis = 3000_000;
781         long userDefinedIntervalMillis = 4000_000;
782         TrainingOptions trainerOptions =
783                 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1)
784                         .setTrainingInterval(
785                                 new TrainingInterval.Builder()
786                                         .setSchedulingMode(
787                                                 TrainingInterval.SCHEDULING_MODE_RECURRENT)
788                                         .setMinimumIntervalMillis(userDefinedIntervalMillis)
789                                         .build())
790                         .build();
791         long nowMillis = 1000;
792         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
793         mJobManager.onTrainerStartCalled(trainerOptions, new TestFederatedComputeCallback());
794 
795         nowMillis = 2000;
796         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
797         mJobManager.onTrainingStarted(JOB_ID1);
798 
799         nowMillis = 3000;
800         byte[] intervalOptions =
801                 createTrainingIntervalOptions(SchedulingMode.RECURRENT, userDefinedIntervalMillis);
802         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
803         mJobManager.onTrainingCompleted(
804                 JOB_ID1,
805                 POPULATION_NAME1,
806                 TrainingIntervalOptions.getRootAsTrainingIntervalOptions(
807                         ByteBuffer.wrap(intervalOptions)),
808                 new TaskRetry.Builder()
809                         .setMinDelay(minRetryDelayMillis)
810                         .setMaxDelay(maxRetryDelayMillis)
811                         .build(),
812                 TrainingResult.SUCCESS);
813 
814         List<FederatedTrainingTask> taskList =
815                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
816         FederatedTrainingTask expectedTask =
817                 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, intervalOptions)
818                         .creationTime(1000L)
819                         .lastScheduledTime(1000L)
820                         .lastRunStartTime(2000L) // Match the time of calling onTrainingStarted()
821                         .lastRunEndTime(3000L) // Match the time of calling onTrainingCompleted()
822                         .schedulingReason(
823                                 SchedulingReason.SCHEDULING_REASON_FEDERATED_COMPUTATION_RETRY)
824                         .earliestNextRunTime(3000 + userDefinedIntervalMillis)
825                         .build();
826         assertThat(taskList).containsExactly(expectedTask);
827 
828         assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1);
829         assertJobInfosMatch(
830                 mJobScheduler.getPendingJob(JOB_ID1),
831                 buildExpectedJobInfo(JOB_ID1, userDefinedIntervalMillis));
832     }
833 
834     @Test
testRescheduleFLTask_recurrent_serverDefinedInterval()835     public void testRescheduleFLTask_recurrent_serverDefinedInterval() throws Exception {
836         // Define a server returned interval which is larger than the user defined interval
837         long serverDefinedIntervalMillis = 4000_000;
838         long userDefinedIntervalMillis = 3000_000;
839 
840         TrainingOptions trainerOptions =
841                 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1)
842                         .setTrainingInterval(
843                                 new TrainingInterval.Builder()
844                                         .setSchedulingMode(
845                                                 TrainingInterval.SCHEDULING_MODE_RECURRENT)
846                                         .setMinimumIntervalMillis(userDefinedIntervalMillis)
847                                         .build())
848                         .build();
849 
850         long nowMillis = 1000;
851         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
852         mJobManager.onTrainerStartCalled(trainerOptions, new TestFederatedComputeCallback());
853 
854         nowMillis = 2000;
855         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
856         mJobManager.onTrainingStarted(JOB_ID1);
857 
858         nowMillis = 3000;
859         byte[] intervalOptions =
860                 createTrainingIntervalOptions(SchedulingMode.RECURRENT, userDefinedIntervalMillis);
861         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
862         mJobManager.onTrainingCompleted(
863                 JOB_ID1,
864                 POPULATION_NAME1,
865                 TrainingIntervalOptions.getRootAsTrainingIntervalOptions(
866                         ByteBuffer.wrap(intervalOptions)),
867                 new TaskRetry.Builder()
868                         .setMinDelay(serverDefinedIntervalMillis)
869                         .setMaxDelay(serverDefinedIntervalMillis)
870                         .build(),
871                 TrainingResult.SUCCESS);
872 
873         List<FederatedTrainingTask> taskList =
874                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
875         FederatedTrainingTask expectedTask =
876                 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, intervalOptions)
877                         .creationTime(1000L)
878                         .lastScheduledTime(1000L)
879                         .lastRunStartTime(2000L) // Match the time of calling onTrainingStarted()
880                         .lastRunEndTime(3000L) // Match the time of calling onTrainingCompleted()
881                         .schedulingReason(
882                                 SchedulingReason.SCHEDULING_REASON_FEDERATED_COMPUTATION_RETRY)
883                         .earliestNextRunTime(3000 + serverDefinedIntervalMillis)
884                         .build();
885         assertThat(taskList).containsExactly(expectedTask);
886 
887         assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1);
888         assertJobInfosMatch(
889                 mJobScheduler.getPendingJob(JOB_ID1),
890                 buildExpectedJobInfo(JOB_ID1, serverDefinedIntervalMillis));
891     }
892 
893     @Test
testRescheduleFLTask_recurrent_didnotContribute()894     public void testRescheduleFLTask_recurrent_didnotContribute() throws Exception {
895         // Define a server returned interval which is larger than the user defined interval
896         long serverDefinedIntervalMillis = 4000_000;
897         long userDefinedIntervalMillis = 3000_000;
898 
899         TrainingOptions trainerOptions =
900                 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1)
901                         .setTrainingInterval(
902                                 new TrainingInterval.Builder()
903                                         .setSchedulingMode(
904                                                 TrainingInterval.SCHEDULING_MODE_RECURRENT)
905                                         .setMinimumIntervalMillis(userDefinedIntervalMillis)
906                                         .build())
907                         .build();
908 
909         long nowMillis = 1000;
910         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
911         mJobManager.onTrainerStartCalled(trainerOptions, new TestFederatedComputeCallback());
912 
913         nowMillis = 2000;
914         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
915         mJobManager.onTrainingStarted(JOB_ID1);
916 
917         nowMillis = 3000;
918         byte[] intervalOptions =
919                 createTrainingIntervalOptions(SchedulingMode.RECURRENT, userDefinedIntervalMillis);
920         when(mClock.currentTimeMillis()).thenReturn(nowMillis);
921         mJobManager.onTrainingCompleted(
922                 JOB_ID1,
923                 POPULATION_NAME1,
924                 TrainingIntervalOptions.getRootAsTrainingIntervalOptions(
925                         ByteBuffer.wrap(intervalOptions)),
926                 new TaskRetry.Builder()
927                         .setMinDelay(serverDefinedIntervalMillis)
928                         .setMaxDelay(serverDefinedIntervalMillis)
929                         .build(),
930                 TrainingResult.FAIL);
931 
932         List<FederatedTrainingTask> taskList =
933                 mTrainingTaskDao.getFederatedTrainingTask(null, null);
934         FederatedTrainingTask expectedTask =
935                 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, intervalOptions)
936                         .creationTime(1000L)
937                         .lastScheduledTime(1000L)
938                         .lastRunStartTime(2000L) // Match the time of calling onTrainingStarted()
939                         .lastRunEndTime(3000L) // Match the time of calling onTrainingCompleted()
940                         .schedulingReason(
941                                 SchedulingReason.SCHEDULING_REASON_FEDERATED_COMPUTATION_RETRY)
942                         .earliestNextRunTime(3000 + serverDefinedIntervalMillis)
943                         .build();
944         assertThat(taskList).containsExactly(expectedTask);
945 
946         assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1);
947         assertJobInfosMatch(
948                 mJobScheduler.getPendingJob(JOB_ID1),
949                 buildExpectedJobInfo(JOB_ID1, serverDefinedIntervalMillis));
950     }
951 
952     /**
953      * Helper for checking that two JobInfos match, since JobInfos unfortunately can't be compared
954      * directly.
955      */
assertJobInfosMatch(JobInfo pendingJob, JobInfo expectedJobInfo)956     public static void assertJobInfosMatch(JobInfo pendingJob, JobInfo expectedJobInfo) {
957         // Compare most of JobInfo's properties that may be set by our code.
958         assertWithMessage("id").that(pendingJob.getId()).isEqualTo(expectedJobInfo.getId());
959         assertWithMessage("service")
960                 .that(pendingJob.getService())
961                 .isEqualTo(expectedJobInfo.getService());
962         assertWithMessage("persisted")
963                 .that(pendingJob.isPersisted())
964                 .isEqualTo(expectedJobInfo.isPersisted());
965         assertWithMessage("networkType")
966                 .that(pendingJob.getNetworkType())
967                 .isEqualTo(expectedJobInfo.getNetworkType());
968         assertWithMessage("requireDeviceIdle")
969                 .that(pendingJob.isRequireDeviceIdle())
970                 .isEqualTo(expectedJobInfo.isRequireDeviceIdle());
971         assertWithMessage("requireCharging")
972                 .that(pendingJob.isRequireCharging())
973                 .isEqualTo(expectedJobInfo.isRequireCharging());
974         assertWithMessage("minLatencyMillis")
975                 .that(pendingJob.getMinLatencyMillis())
976                 .isEqualTo(expectedJobInfo.getMinLatencyMillis());
977         assertWithMessage("maxExecutionDelayMillis")
978                 .that(pendingJob.getMaxExecutionDelayMillis())
979                 .isEqualTo(expectedJobInfo.getMaxExecutionDelayMillis());
980     }
981 
basicFLOptionsBuilder(int jobId, String population)982     private static TrainingOptions.Builder basicFLOptionsBuilder(int jobId, String population) {
983         return new TrainingOptions.Builder()
984                 .setPopulationName(population)
985                 .setJobSchedulerJobId(jobId);
986     }
987 
buildExpectedJobInfo(int jobId, long minLatencyMillis)988     private JobInfo buildExpectedJobInfo(int jobId, long minLatencyMillis) {
989         JobInfo.Builder jobInfo =
990                 new JobInfo.Builder(
991                                 jobId,
992                                 new ComponentName(mContext.getPackageName(), TRAINING_JOB_SERVICE))
993                         .setPersisted(true)
994                         .setRequiresDeviceIdle(true)
995                         // the latency should be capped.
996                         .setMinimumLatency(minLatencyMillis)
997                         .setRequiresCharging(true);
998         jobInfo.setRequiredNetworkType(JobInfo.NETWORK_TYPE_UNMETERED);
999 
1000         return jobInfo.build();
1001     }
1002 
basicFLTrainingTaskBuilder( int jobId, String population, @Nullable byte[] trainingIntervalOptions)1003     private FederatedTrainingTask.Builder basicFLTrainingTaskBuilder(
1004             int jobId, String population, @Nullable byte[] trainingIntervalOptions) {
1005         FederatedTrainingTask.Builder builder =
1006                 FederatedTrainingTask.builder()
1007                         .jobId(jobId)
1008                         .populationName(population)
1009                         .lastScheduledTime(0L)
1010                         .lastRunStartTime(0L)
1011                         .lastRunEndTime(0L)
1012                         .constraints(DEFAULT_CONSTRAINTS)
1013                         .appPackageName(mContext.getPackageName());
1014         if (trainingIntervalOptions != null) {
1015             builder.intervalOptions(trainingIntervalOptions);
1016         }
1017         return builder;
1018     }
1019 
createTrainingIntervalOptionsAsRoot( int schedulingMode, long intervalMillis)1020     private static TrainingIntervalOptions createTrainingIntervalOptionsAsRoot(
1021             int schedulingMode, long intervalMillis) {
1022         byte[] intervalOptions = createTrainingIntervalOptions(schedulingMode, intervalMillis);
1023         return TrainingIntervalOptions.getRootAsTrainingIntervalOptions(
1024                 ByteBuffer.wrap(intervalOptions));
1025     }
1026 
createTrainingIntervalOptions(int schedulingMode, long intervalMillis)1027     private static byte[] createTrainingIntervalOptions(int schedulingMode, long intervalMillis) {
1028         FlatBufferBuilder builder = new FlatBufferBuilder();
1029         builder.finish(
1030                 TrainingIntervalOptions.createTrainingIntervalOptions(
1031                         builder, schedulingMode, intervalMillis));
1032         return builder.sizedByteArray();
1033     }
1034 
createDefaultTrainingConstraints()1035     private static byte[] createDefaultTrainingConstraints() {
1036         FlatBufferBuilder builder = new FlatBufferBuilder();
1037         builder.finish(TrainingConstraints.createTrainingConstraints(builder, true, true, true));
1038         return builder.sizedByteArray();
1039     }
1040 
1041     class TestFederatedComputeCallback extends IFederatedComputeCallback.Stub {
1042         @Override
onSuccess()1043         public void onSuccess() {
1044             mSuccess = true;
1045             mLatch.countDown();
1046         }
1047 
1048         @Override
onFailure(int errorCode)1049         public void onFailure(int errorCode) {
1050             mLatch.countDown();
1051         }
1052     }
1053 }
1054