1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.federatedcompute.services.scheduling; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static com.google.common.truth.Truth.assertWithMessage; 21 22 import static org.mockito.Mockito.when; 23 24 import static java.lang.Math.min; 25 26 import android.app.job.JobInfo; 27 import android.app.job.JobScheduler; 28 import android.content.ComponentName; 29 import android.content.Context; 30 import android.federatedcompute.aidl.IFederatedComputeCallback; 31 import android.federatedcompute.common.TrainingInterval; 32 import android.federatedcompute.common.TrainingOptions; 33 34 import androidx.test.core.app.ApplicationProvider; 35 36 import com.android.federatedcompute.services.common.Clock; 37 import com.android.federatedcompute.services.common.Flags; 38 import com.android.federatedcompute.services.common.TaskRetry; 39 import com.android.federatedcompute.services.common.TrainingResult; 40 import com.android.federatedcompute.services.data.FederatedTrainingTask; 41 import com.android.federatedcompute.services.data.FederatedTrainingTaskDao; 42 import com.android.federatedcompute.services.data.FederatedTrainingTaskDbHelper; 43 import com.android.federatedcompute.services.data.fbs.SchedulingMode; 44 import com.android.federatedcompute.services.data.fbs.SchedulingReason; 45 import 
com.android.federatedcompute.services.data.fbs.TrainingConstraints; 46 import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions; 47 48 import com.google.flatbuffers.FlatBufferBuilder; 49 50 import org.junit.After; 51 import org.junit.Before; 52 import org.junit.Test; 53 import org.junit.runner.RunWith; 54 import org.mockito.Mock; 55 import org.mockito.junit.MockitoJUnitRunner; 56 57 import java.nio.ByteBuffer; 58 import java.util.List; 59 import java.util.concurrent.CountDownLatch; 60 61 import javax.annotation.Nullable; 62 63 @RunWith(MockitoJUnitRunner.class) 64 public final class FederatedComputeJobManagerTest { 65 private static final String POPULATION_NAME1 = "population1"; 66 private static final String POPULATION_NAME2 = "population2"; 67 private static final int JOB_ID1 = 700000001; 68 private static final int JOB_ID2 = 700000002; 69 private static final long DEFAULT_SCHEDULING_PERIOD_SECS = 1234; 70 private static final long DEFAULT_SCHEDULING_PERIOD_MILLIS = 71 DEFAULT_SCHEDULING_PERIOD_SECS * 1000; 72 private static final long MAX_SCHEDULING_PERIOD_SECS = 912000; 73 private static final long MAX_SCHEDULING_INTERVAL_SECS_FOR_FEDERATED_COMPUTATION = 604800L; 74 75 private static final String TRAINING_JOB_SERVICE = 76 "com.android.federatedcompute.services.training.FederatedJobService"; 77 private static final long CURRENT_TIME_MILLIS = 1000L; 78 private static final byte[] DEFAULT_CONSTRAINTS = createDefaultTrainingConstraints(); 79 private static final TrainingOptions OPTIONS1 = 80 new TrainingOptions.Builder() 81 .setPopulationName(POPULATION_NAME1) 82 .setJobSchedulerJobId(JOB_ID1) 83 .build(); 84 private static final TrainingOptions OPTIONS2 = 85 new TrainingOptions.Builder() 86 .setPopulationName(POPULATION_NAME2) 87 .setJobSchedulerJobId(JOB_ID2) 88 .build(); 89 private static final TaskRetry TASK_RETRY = 90 new TaskRetry.Builder().setMinDelay(5000000).setMaxDelay(6000000).build(); 91 92 private FederatedComputeJobManager 
mJobManager; 93 private Context mContext; 94 private FederatedTrainingTaskDao mTrainingTaskDao; 95 private boolean mSuccess = false; 96 private final CountDownLatch mLatch = new CountDownLatch(1); 97 98 @Mock private Clock mClock; 99 @Mock private Flags mMockFlags; 100 private JobScheduler mJobScheduler; 101 102 @Before setUp()103 public void setUp() { 104 mContext = ApplicationProvider.getApplicationContext(); 105 mJobScheduler = mContext.getSystemService(JobScheduler.class); 106 mJobScheduler.cancelAll(); 107 mTrainingTaskDao = FederatedTrainingTaskDao.getInstanceForTest(mContext); 108 mJobManager = 109 new FederatedComputeJobManager( 110 mContext, 111 mTrainingTaskDao, 112 new JobSchedulerHelper(mClock), 113 mClock, 114 mMockFlags); 115 when(mClock.currentTimeMillis()).thenReturn(CURRENT_TIME_MILLIS); 116 when(mMockFlags.getDefaultSchedulingPeriodSecs()) 117 .thenReturn(DEFAULT_SCHEDULING_PERIOD_SECS); 118 when(mMockFlags.getMaxSchedulingIntervalSecsForFederatedComputation()) 119 .thenReturn(MAX_SCHEDULING_INTERVAL_SECS_FOR_FEDERATED_COMPUTATION); 120 when(mMockFlags.getMinSchedulingIntervalSecsForFederatedComputation()).thenReturn(1L); 121 when(mMockFlags.getMaxSchedulingPeriodSecs()).thenReturn(MAX_SCHEDULING_PERIOD_SECS); 122 } 123 124 @After tearDown()125 public void tearDown() { 126 // Manually clean up the database. 
127 mTrainingTaskDao.clearDatabase(); 128 FederatedTrainingTaskDbHelper dbHelper = 129 FederatedTrainingTaskDbHelper.getInstanceForTest(mContext); 130 dbHelper.getWritableDatabase().close(); 131 dbHelper.getReadableDatabase().close(); 132 dbHelper.close(); 133 } 134 135 @Test testOnTrainerStartCalledSuccess()136 public void testOnTrainerStartCalledSuccess() throws Exception { 137 when(mClock.currentTimeMillis()).thenReturn(1000L).thenReturn(2000L); 138 139 mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback()); 140 141 assertThat(mSuccess).isTrue(); 142 List<FederatedTrainingTask> taskList = 143 mTrainingTaskDao.getFederatedTrainingTask(null, null); 144 assertThat(taskList) 145 .containsExactly( 146 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, null) 147 .creationTime(1000L) 148 .lastScheduledTime(1000L) 149 .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) 150 .earliestNextRunTime(1000 + DEFAULT_SCHEDULING_PERIOD_MILLIS) 151 .build()); 152 } 153 154 @Test testOnTrainerStartCalled_firstTime()155 public void testOnTrainerStartCalled_firstTime() throws Exception { 156 when(mClock.currentTimeMillis()).thenReturn(1000L); 157 // Make two onTrainerStart calls, each with a different job ID and population name. 158 mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback()); 159 when(mClock.currentTimeMillis()).thenReturn(2000L); 160 mJobManager.onTrainerStartCalled(OPTIONS2, new TestFederatedComputeCallback()); 161 mLatch.await(); 162 163 assertThat(mSuccess).isTrue(); 164 // verify training tasks in database. 
165 List<FederatedTrainingTask> taskList = 166 mTrainingTaskDao.getFederatedTrainingTask(null, null); 167 assertThat(taskList) 168 .containsExactly( 169 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, null) 170 .creationTime(1000L) 171 .lastScheduledTime(1000L) 172 .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) 173 .earliestNextRunTime(1000 + DEFAULT_SCHEDULING_PERIOD_MILLIS) 174 .build(), 175 basicFLTrainingTaskBuilder(JOB_ID2, POPULATION_NAME2, null) 176 .creationTime(2000L) 177 .lastScheduledTime(2000L) 178 .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) 179 .earliestNextRunTime(2000 + DEFAULT_SCHEDULING_PERIOD_MILLIS) 180 .build()); 181 182 assertThat(mJobScheduler.getAllPendingJobs()).hasSize(2); 183 assertJobInfosMatch( 184 mJobScheduler.getPendingJob(JOB_ID1), 185 buildExpectedJobInfo(JOB_ID1, DEFAULT_SCHEDULING_PERIOD_MILLIS)); 186 assertJobInfosMatch( 187 mJobScheduler.getPendingJob(JOB_ID2), 188 buildExpectedJobInfo(JOB_ID2, DEFAULT_SCHEDULING_PERIOD_MILLIS)); 189 } 190 191 @Test testOnTrainerStartCalledFL_withIntervalSmallerThanDefaultInterval()192 public void testOnTrainerStartCalledFL_withIntervalSmallerThanDefaultInterval() 193 throws Exception { 194 testOnTrainerStartCalledFLWithInterval( 195 /* userDefinedIntervalMillis= */ 1000000, /* defaultIntervalMillis= */ 2000000); 196 } 197 198 @Test testOnTrainerStartCalledFL_withIntervalLargerThanDefaultInterval()199 public void testOnTrainerStartCalledFL_withIntervalLargerThanDefaultInterval() 200 throws Exception { 201 testOnTrainerStartCalledFLWithInterval( 202 /* userDefinedIntervalMillis= */ 2000000, /* defaultIntervalMillis= */ 1000000); 203 } 204 testOnTrainerStartCalledFLWithInterval( long userDefinedIntervalMillis, long defaultIntervalMillis)205 private void testOnTrainerStartCalledFLWithInterval( 206 long userDefinedIntervalMillis, long defaultIntervalMillis) throws Exception { 207 
when(mMockFlags.getDefaultSchedulingPeriodSecs()).thenReturn(defaultIntervalMillis / 1000); 208 TrainingOptions trainerOptions = 209 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) 210 .setTrainingInterval( 211 new TrainingInterval.Builder() 212 .setSchedulingMode( 213 TrainingInterval.SCHEDULING_MODE_RECURRENT) 214 .setMinimumIntervalMillis(userDefinedIntervalMillis) 215 .build()) 216 .build(); 217 mJobManager.onTrainerStartCalled(trainerOptions, new TestFederatedComputeCallback()); 218 219 byte[] trainingIntervalOptions = 220 createTrainingIntervalOptions(SchedulingMode.RECURRENT, userDefinedIntervalMillis); 221 222 long expectedInterval = min(userDefinedIntervalMillis, defaultIntervalMillis); 223 FederatedTrainingTask expectedTask = 224 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, trainingIntervalOptions) 225 .earliestNextRunTime(CURRENT_TIME_MILLIS + expectedInterval) 226 .lastScheduledTime(CURRENT_TIME_MILLIS) 227 .creationTime(CURRENT_TIME_MILLIS) 228 .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) 229 .build(); 230 List<FederatedTrainingTask> taskList = 231 mTrainingTaskDao.getFederatedTrainingTask(null, null); 232 233 assertThat(taskList).containsExactly(expectedTask); 234 assertJobInfosMatch( 235 mJobScheduler.getPendingJob(JOB_ID1), 236 buildExpectedJobInfo(JOB_ID1, expectedInterval)); 237 } 238 239 /** 240 * Tests onTrainerStart being called multiple times with the same parameters (the common 241 * expected use case). 242 * 243 * <p>After the first call, most fields in the task (like creation time, earliest next run time, 244 * etc.) must be preserved, and only certain fields (like last scheduled time) should be 245 * updated. 
246 */ 247 @Test testOnTrainerStartCalled_multipleTimes_sameParams()248 public void testOnTrainerStartCalled_multipleTimes_sameParams() throws Exception { 249 when(mClock.currentTimeMillis()).thenReturn(1000L); 250 mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback()); 251 252 when(mClock.currentTimeMillis()).thenReturn(2000L); 253 mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback()); 254 255 when(mClock.currentTimeMillis()).thenReturn(3000L); 256 mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback()); 257 258 List<FederatedTrainingTask> taskList = 259 mTrainingTaskDao.getFederatedTrainingTask(null, null); 260 FederatedTrainingTask expectedTask = 261 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, null) 262 .earliestNextRunTime(1000 + DEFAULT_SCHEDULING_PERIOD_MILLIS) 263 .lastScheduledTime(3000L) 264 .creationTime(1000L) 265 .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) 266 .build(); 267 assertThat(taskList).containsExactly(expectedTask); 268 269 assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); 270 assertJobInfosMatch( 271 mJobScheduler.getPendingJob(JOB_ID1), 272 buildExpectedJobInfo(JOB_ID1, DEFAULT_SCHEDULING_PERIOD_MILLIS)); 273 } 274 275 /** 276 * Tests when the user specified interval is larger than the maximum server specified interval, 277 * multiple scheduling with same user specified interval will not be incorrectly capped at the 278 * maximum server specified interval. 
279 */ 280 @Test testOnTrainerStartCalled_multipleTimes_sameParamsFLWithIntervalLargerThanServerMax()281 public void testOnTrainerStartCalled_multipleTimes_sameParamsFLWithIntervalLargerThanServerMax() 282 throws Exception { 283 long minIntervalMills = 10000L; // 10 seconds 284 // Maximum server specified interval is 5 seconds 285 when(mMockFlags.getMaxSchedulingPeriodSecs()).thenReturn(5L); 286 TrainingOptions trainingOptions = 287 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) 288 .setTrainingInterval( 289 new TrainingInterval.Builder() 290 .setSchedulingMode( 291 TrainingInterval.SCHEDULING_MODE_RECURRENT) 292 .setMinimumIntervalMillis(minIntervalMills) 293 .build()) 294 .build(); 295 296 when(mClock.currentTimeMillis()).thenReturn(1000L); 297 mJobManager.onTrainerStartCalled(trainingOptions, new TestFederatedComputeCallback()); 298 299 when(mClock.currentTimeMillis()).thenReturn(2000L); 300 mJobManager.onTrainerStartCalled(trainingOptions, new TestFederatedComputeCallback()); 301 302 when(mClock.currentTimeMillis()).thenReturn(3000L); 303 mJobManager.onTrainerStartCalled(trainingOptions, new TestFederatedComputeCallback()); 304 305 List<FederatedTrainingTask> taskList = 306 mTrainingTaskDao.getFederatedTrainingTask(null, null); 307 byte[] expectedInterval = 308 createTrainingIntervalOptions(SchedulingMode.RECURRENT, minIntervalMills); 309 FederatedTrainingTask expectedTask = 310 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, expectedInterval) 311 .earliestNextRunTime(1000 + minIntervalMills) 312 .lastScheduledTime(3000L) 313 .creationTime(1000L) 314 .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) 315 .build(); 316 assertThat(taskList).containsExactly(expectedTask); 317 318 assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); 319 assertJobInfosMatch( 320 mJobScheduler.getPendingJob(JOB_ID1), 321 buildExpectedJobInfo(JOB_ID1, minIntervalMills)); 322 } 323 324 /** 325 * Tests when a task got scheduled with the same set of parameters 
multiple times, the brella 326 * defined max for user specified interval has been lowered between the multiple scheduling 327 * events, the user specified interval should be always guarded with the latest max. 328 */ 329 @Test testOnTrainerStartCalled_multipleTimes_sameParamsFLWithIntervalDifferentMax()330 public void testOnTrainerStartCalled_multipleTimes_sameParamsFLWithIntervalDifferentMax() 331 throws Exception { 332 // Initial max 20 seconds is larger than the user specified interval. 333 when(mMockFlags.getMaxSchedulingIntervalSecsForFederatedComputation()).thenReturn(20L); 334 long minIntervalMills = 10000L; // 10 seconds 335 TrainingOptions trainingOptions = 336 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) 337 .setTrainingInterval( 338 new TrainingInterval.Builder() 339 .setSchedulingMode( 340 TrainingInterval.SCHEDULING_MODE_RECURRENT) 341 .setMinimumIntervalMillis(minIntervalMills) 342 .build()) 343 .build(); 344 345 when(mClock.currentTimeMillis()).thenReturn(1000L); 346 mJobManager.onTrainerStartCalled(trainingOptions, new TestFederatedComputeCallback()); 347 348 List<FederatedTrainingTask> taskList = 349 mTrainingTaskDao.getFederatedTrainingTask(null, null); 350 byte[] expectedInterval = 351 createTrainingIntervalOptions(SchedulingMode.RECURRENT, minIntervalMills); 352 FederatedTrainingTask expectedTask = 353 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, expectedInterval) 354 .earliestNextRunTime(1000L + minIntervalMills) 355 .lastScheduledTime(1000L) 356 .creationTime(1000L) 357 .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) 358 .build(); 359 assertThat(taskList).containsExactly(expectedTask); 360 361 assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); 362 assertJobInfosMatch( 363 mJobScheduler.getPendingJob(JOB_ID1), 364 buildExpectedJobInfo(JOB_ID1, minIntervalMills)); 365 366 // Now lower allowed max for the user specified interval 367 long newMaxSec = 5L; 368 long newMinIntervalMills = newMaxSec * 1000; 369 
when(mMockFlags.getMaxSchedulingIntervalSecsForFederatedComputation()) 370 .thenReturn(newMaxSec); 371 TrainingOptions newTrainingOptions = 372 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) 373 .setTrainingInterval( 374 new TrainingInterval.Builder() 375 .setSchedulingMode( 376 TrainingInterval.SCHEDULING_MODE_RECURRENT) 377 .setMinimumIntervalMillis(newMinIntervalMills) 378 .build()) 379 .build(); 380 381 when(mClock.currentTimeMillis()).thenReturn(2000L); 382 mJobManager.onTrainerStartCalled(newTrainingOptions, new TestFederatedComputeCallback()); 383 384 taskList = mTrainingTaskDao.getFederatedTrainingTask(null, null); 385 expectedInterval = 386 createTrainingIntervalOptions(SchedulingMode.RECURRENT, newMinIntervalMills); 387 expectedTask = 388 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, expectedInterval) 389 .earliestNextRunTime(2000L + newMinIntervalMills) 390 .lastScheduledTime(2000L) 391 .creationTime(1000L) 392 .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) 393 .build(); 394 assertThat(taskList).containsExactly(expectedTask); 395 assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); 396 assertJobInfosMatch( 397 mJobScheduler.getPendingJob(JOB_ID1), 398 buildExpectedJobInfo(JOB_ID1, newMinIntervalMills)); 399 } 400 401 @Test testOnTrainerStartCalled_fLCustomerSpecifiedIntervalSmallerThanDefinedMin()402 public void testOnTrainerStartCalled_fLCustomerSpecifiedIntervalSmallerThanDefinedMin() 403 throws Exception { 404 when(mMockFlags.getDefaultSchedulingPeriodSecs()).thenReturn(2000L); 405 long minTrainingIntervalSecByFederatedCompute = 1800L; 406 long minTrainingIntervalMillsByFederatedCompute = 407 minTrainingIntervalSecByFederatedCompute * 1000; 408 when(mMockFlags.getMinSchedulingIntervalSecsForFederatedComputation()) 409 .thenReturn(minTrainingIntervalSecByFederatedCompute); 410 411 TrainingOptions trainingOptions = 412 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) 413 .setTrainingInterval( 414 new 
TrainingInterval.Builder() 415 .setSchedulingMode( 416 TrainingInterval.SCHEDULING_MODE_RECURRENT) 417 .setMinimumIntervalMillis(1000L) 418 .build()) 419 .build(); 420 421 when(mClock.currentTimeMillis()).thenReturn(1000L); 422 mJobManager.onTrainerStartCalled(trainingOptions, new TestFederatedComputeCallback()); 423 424 List<FederatedTrainingTask> taskList = 425 mTrainingTaskDao.getFederatedTrainingTask(null, null); 426 byte[] expectedInterval = createTrainingIntervalOptions(SchedulingMode.RECURRENT, 1000L); 427 FederatedTrainingTask expectedTask = 428 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, expectedInterval) 429 .earliestNextRunTime(1000L + minTrainingIntervalMillsByFederatedCompute) 430 .lastScheduledTime(1000L) 431 .creationTime(1000L) 432 .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) 433 .build(); 434 assertThat(taskList).containsExactly(expectedTask); 435 436 assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); 437 assertJobInfosMatch( 438 mJobScheduler.getPendingJob(JOB_ID1), 439 buildExpectedJobInfo(JOB_ID1, minTrainingIntervalMillsByFederatedCompute)); 440 } 441 442 @Test testOnTrainerStartCalled_trainingIntervalChange_FL()443 public void testOnTrainerStartCalled_trainingIntervalChange_FL() throws Exception { 444 when(mClock.currentTimeMillis()).thenReturn(1000L); 445 mJobManager.onTrainerStartCalled( 446 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1).build(), 447 new TestFederatedComputeCallback()); 448 449 long minTrainingIntervalMillis = 60000L; 450 when(mClock.currentTimeMillis()).thenReturn(2000L); 451 mJobManager.onTrainerStartCalled( 452 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) 453 .setTrainingInterval( 454 new TrainingInterval.Builder() 455 .setSchedulingMode( 456 TrainingInterval.SCHEDULING_MODE_RECURRENT) 457 .setMinimumIntervalMillis(minTrainingIntervalMillis) 458 .build()) 459 .build(), 460 new TestFederatedComputeCallback()); 461 byte[] trainingInterval = 462 
createTrainingIntervalOptions(SchedulingMode.RECURRENT, minTrainingIntervalMillis); 463 verifyTaskAndJobAfterIntervalChange( 464 trainingInterval, 1000, 2000, minTrainingIntervalMillis); 465 466 long newInterval = 70000L; 467 when(mClock.currentTimeMillis()).thenReturn(3000L); 468 mJobManager.onTrainerStartCalled( 469 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) 470 .setTrainingInterval( 471 new TrainingInterval.Builder() 472 .setSchedulingMode( 473 TrainingInterval.SCHEDULING_MODE_RECURRENT) 474 .setMinimumIntervalMillis(newInterval) 475 .build()) 476 .build(), 477 new TestFederatedComputeCallback()); 478 byte[] trainingIntervalOption2 = 479 createTrainingIntervalOptions(SchedulingMode.RECURRENT, newInterval); 480 // Verify the creation time is not changed, modified time is set to now, and the min interval 481 // is set to the new interval. 482 verifyTaskAndJobAfterIntervalChange(trainingIntervalOption2, 1000, 3000, newInterval); 483 484 when(mClock.currentTimeMillis()).thenReturn(4000L); 485 mJobManager.onTrainerStartCalled( 486 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) 487 .setTrainingInterval( 488 new TrainingInterval.Builder() 489 .setSchedulingMode( 490 TrainingInterval.SCHEDULING_MODE_ONE_TIME) 491 .build()) 492 .build(), 493 new TestFederatedComputeCallback()); 494 byte[] trainingIntervalOption3 = createTrainingIntervalOptions(SchedulingMode.ONE_TIME, 0L); 495 // Verify the creation time is not changed, modified time is set to now, and, since one-time 496 // mode sets no minimum interval, the expected interval is the default scheduling period. 
497 verifyTaskAndJobAfterIntervalChange( 498 trainingIntervalOption3, 1000, 4000, DEFAULT_SCHEDULING_PERIOD_MILLIS); 499 500 // Transition back to not set 501 when(mClock.currentTimeMillis()).thenReturn(5000L); 502 mJobManager.onTrainerStartCalled( 503 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1).build(), 504 new TestFederatedComputeCallback()); 505 verifyTaskAndJobAfterIntervalChange(null, 1000, 5000, DEFAULT_SCHEDULING_PERIOD_MILLIS); 506 } 507 verifyTaskAndJobAfterIntervalChange( @ullable byte[] trainingIntervalOptions, long createTimeMillis, long modifyTimeMillis, long expectedIntervalMillis)508 private void verifyTaskAndJobAfterIntervalChange( 509 @Nullable byte[] trainingIntervalOptions, 510 long createTimeMillis, 511 long modifyTimeMillis, 512 long expectedIntervalMillis) 513 throws Exception { 514 List<FederatedTrainingTask> taskList = 515 mTrainingTaskDao.getFederatedTrainingTask(null, null); 516 FederatedTrainingTask expectedTask = 517 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, trainingIntervalOptions) 518 .earliestNextRunTime(modifyTimeMillis + expectedIntervalMillis) 519 .lastScheduledTime(modifyTimeMillis) 520 .creationTime(createTimeMillis) 521 .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) 522 .build(); 523 assertThat(taskList).containsExactly(expectedTask); 524 525 assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); 526 assertJobInfosMatch( 527 mJobScheduler.getPendingJob(JOB_ID1), 528 buildExpectedJobInfo(JOB_ID1, expectedIntervalMillis)); 529 } 530 531 @Test testOnTrainerStartCalled_multipleTimes_changingPopulationName()532 public void testOnTrainerStartCalled_multipleTimes_changingPopulationName() throws Exception { 533 // Only change the population name. 
534 int jobId = JOB_ID1; 535 doTestOnTrainerStartCalled_multipleTimes_changingParams( 536 jobId, 537 POPULATION_NAME1, 538 jobId, 539 POPULATION_NAME2, 540 SchedulingReason.SCHEDULING_REASON_NEW_TASK); 541 } 542 543 @Test testOnTrainerStartCalled_twoJobsWithSamePopulationName()544 public void testOnTrainerStartCalled_twoJobsWithSamePopulationName() throws Exception { 545 // Use two different job IDs with the same population name between Trainer.start calls. 546 doTestOnTrainerStartCalled_multipleTimes_changingParams( 547 JOB_ID1, 548 POPULATION_NAME1, 549 JOB_ID2, 550 POPULATION_NAME1, 551 SchedulingReason.SCHEDULING_REASON_NEW_TASK); 552 } 553 554 @Test testOnTrainerStartCalled_multipleTimes_changingJobId()555 public void testOnTrainerStartCalled_multipleTimes_changingJobId() throws Exception { 556 // Only change the job ID. 557 String populationName = POPULATION_NAME1; 558 doTestOnTrainerStartCalled_multipleTimes_changingParams( 559 JOB_ID1, 560 populationName, 561 JOB_ID2, 562 populationName, 563 SchedulingReason.SCHEDULING_REASON_NEW_TASK); 564 } 565 doTestOnTrainerStartCalled_multipleTimes_changingParams( int jobId1, String populationName1, int jobId2, String populationName2, int expectedSchedulingReason)566 private void doTestOnTrainerStartCalled_multipleTimes_changingParams( 567 int jobId1, 568 String populationName1, 569 int jobId2, 570 String populationName2, 571 int expectedSchedulingReason) 572 throws Exception { 573 when(mClock.currentTimeMillis()).thenReturn(1000L); 574 TrainingOptions options1 = 575 new TrainingOptions.Builder() 576 .setPopulationName(populationName1) 577 .setJobSchedulerJobId(jobId1) 578 .build(); 579 mJobManager.onTrainerStartCalled(options1, new TestFederatedComputeCallback()); 580 581 // Schedule again with the second job ID / population name pair (either may differ). 
582 when(mClock.currentTimeMillis()).thenReturn(2000L); 583 TrainingOptions options2 = 584 new TrainingOptions.Builder() 585 .setPopulationName(populationName2) 586 .setJobSchedulerJobId(jobId2) 587 .build(); 588 mJobManager.onTrainerStartCalled(options2, new TestFederatedComputeCallback()); 589 590 long earliestNextRunTimeMillis = 2000 + DEFAULT_SCHEDULING_PERIOD_MILLIS; 591 long minLatencyMillis = DEFAULT_SCHEDULING_PERIOD_MILLIS; 592 // If none of the job id, session name, population name and InAppTrainingConstraints 593 // changes, 594 // the previous earliest next 595 // run time will not change. 596 if (jobId1 == jobId2 && populationName1.equals(populationName2)) { 597 earliestNextRunTimeMillis = 1000 + DEFAULT_SCHEDULING_PERIOD_MILLIS; 598 } 599 List<FederatedTrainingTask> taskList = 600 mTrainingTaskDao.getFederatedTrainingTask(null, null); 601 FederatedTrainingTask expectedTask = 602 basicFLTrainingTaskBuilder(jobId2, populationName2, null) 603 .earliestNextRunTime(earliestNextRunTimeMillis) 604 .lastScheduledTime(2000L) 605 .creationTime(populationName1.equals(populationName2) ? 1000L : 2000L) 606 .constraints(DEFAULT_CONSTRAINTS) 607 .schedulingReason(expectedSchedulingReason) 608 .build(); 609 assertThat(taskList).containsExactly(expectedTask); 610 611 assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); 612 assertJobInfosMatch( 613 mJobScheduler.getPendingJob(jobId2), 614 buildExpectedJobInfo(jobId2, minLatencyMillis)); 615 } 616 617 @Test testOnTrainingStarted_doesNotExist()618 public void testOnTrainingStarted_doesNotExist() throws Exception { 619 when(mClock.currentTimeMillis()).thenReturn(1000L); 620 FederatedTrainingTask taskToRun = mJobManager.onTrainingStarted(JOB_ID1); 621 622 // No task should be found. 
623 assertThat(taskToRun).isNull(); 624 List<FederatedTrainingTask> taskList = 625 mTrainingTaskDao.getFederatedTrainingTask(null, null); 626 assertThat(taskList).isEmpty(); 627 } 628 629 @Test testOnTrainingStarted_taskTtling_noTtlSet()630 public void testOnTrainingStarted_taskTtling_noTtlSet() throws Exception { 631 // Set task TTL to 0, which should disable TTLing. 632 when(mMockFlags.getTrainingTimeForLiveSeconds()).thenReturn(0L); 633 634 long nowMillis = 1000; 635 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 636 mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback()); 637 638 // Attempt to run the task without advancing the mocked clock. This should not fail, b/c a 639 // TTL of 0 should disable TTLing. 640 assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNotNull(); 641 assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).hasSize(1); 642 } 643 644 @Test testOnTrainingStarted_taskTtling()645 public void testOnTrainingStarted_taskTtling() throws Exception { 646 // Set task TTL to 1 second. 647 when(mMockFlags.getTrainingTimeForLiveSeconds()).thenReturn(1L); 648 649 when(mClock.currentTimeMillis()).thenReturn(1000L); 650 mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback()); 651 652 // Simulate attempting to run a task one second later. This should not fail, b/c we're not 653 // yet 654 // past the TTL threshold. 655 long nowMillis = 2000; 656 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 657 assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNotNull(); 658 assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).hasSize(1); 659 660 // Now reschedule again, should keep the task alive for another second. 661 mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback()); 662 663 // The task should again still be alive a second later. 
664 nowMillis = 3000; 665 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 666 assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNotNull(); 667 668 // Now move forward one millisecond. The task should now get TTLd. 669 nowMillis = 3001; 670 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 671 assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNull(); 672 assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).isEmpty(); 673 } 674 675 @Test testRescheduleFLTask_success()676 public void testRescheduleFLTask_success() throws Exception { 677 long nowMillis = 1000; 678 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 679 mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback()); 680 681 nowMillis = 2000; 682 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 683 mJobManager.onTrainingStarted(JOB_ID1); 684 685 nowMillis = 3000; 686 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 687 mJobManager.onTrainingCompleted( 688 JOB_ID1, 689 POPULATION_NAME1, 690 createTrainingIntervalOptionsAsRoot(SchedulingMode.RECURRENT, 0), 691 TASK_RETRY, 692 TrainingResult.SUCCESS); 693 694 assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNotNull(); 695 assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).hasSize(1); 696 } 697 698 @Test testRescheduleFLTask_oneoff_success()699 public void testRescheduleFLTask_oneoff_success() throws Exception { 700 long nowMillis = 1000; 701 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 702 mJobManager.onTrainerStartCalled(OPTIONS1, new TestFederatedComputeCallback()); 703 704 nowMillis = 2000; 705 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 706 mJobManager.onTrainingStarted(JOB_ID1); 707 708 nowMillis = 3000; 709 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 710 mJobManager.onTrainingCompleted( 711 JOB_ID1, 712 POPULATION_NAME1, 713 createTrainingIntervalOptionsAsRoot(SchedulingMode.ONE_TIME, 0), 714 TASK_RETRY, 715 TrainingResult.SUCCESS); 
716 717 assertThat(mJobManager.onTrainingStarted(JOB_ID1)).isNull(); 718 assertThat(mTrainingTaskDao.getFederatedTrainingTask(null, null)).isEmpty(); 719 } 720 721 @Test testRescheduleFLTask_didnotContribute_oneOff()722 public void testRescheduleFLTask_didnotContribute_oneOff() throws Exception { 723 long serverRetryDelayMillis = 5000_000; 724 725 long nowMillis = 1000; 726 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 727 TrainingOptions trainerOptions = 728 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) 729 .setTrainingInterval( 730 new TrainingInterval.Builder() 731 .setSchedulingMode( 732 TrainingInterval.SCHEDULING_MODE_ONE_TIME) 733 .build()) 734 .build(); 735 mJobManager.onTrainerStartCalled(trainerOptions, new TestFederatedComputeCallback()); 736 737 nowMillis = 2000; 738 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 739 mJobManager.onTrainingStarted(JOB_ID1); 740 741 nowMillis = 3000; 742 byte[] intervalOptions = createTrainingIntervalOptions(SchedulingMode.ONE_TIME, 0); 743 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 744 mJobManager.onTrainingCompleted( 745 JOB_ID1, 746 POPULATION_NAME1, 747 TrainingIntervalOptions.getRootAsTrainingIntervalOptions( 748 ByteBuffer.wrap(intervalOptions)), 749 new TaskRetry.Builder() 750 .setMinDelay(serverRetryDelayMillis) 751 .setMaxDelay(serverRetryDelayMillis) 752 .build(), 753 TrainingResult.FAIL); 754 755 List<FederatedTrainingTask> taskList = 756 mTrainingTaskDao.getFederatedTrainingTask(null, null); 757 FederatedTrainingTask expectedTask = 758 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, intervalOptions) 759 .creationTime(1000L) 760 .lastScheduledTime(1000L) 761 .lastRunStartTime(2000L) 762 .lastRunEndTime(3000L) 763 .schedulingReason( 764 SchedulingReason.SCHEDULING_REASON_FEDERATED_COMPUTATION_RETRY) 765 .earliestNextRunTime(3000 + serverRetryDelayMillis) 766 .build(); 767 assertThat(taskList).containsExactly(expectedTask); 768 769 
assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); 770 assertJobInfosMatch( 771 mJobScheduler.getPendingJob(JOB_ID1), 772 buildExpectedJobInfo(JOB_ID1, serverRetryDelayMillis)); 773 } 774 775 /** Reschedule a recurrent fl task with the user defined interval. */ 776 @Test testRescheduleFLTask_success_recurrent_userDefinedInterval()777 public void testRescheduleFLTask_success_recurrent_userDefinedInterval() throws Exception { 778 // The user defined interval is larger than the server specified interval. 779 long minRetryDelayMillis = 3000_000; 780 long maxRetryDelayMillis = 3000_000; 781 long userDefinedIntervalMillis = 4000_000; 782 TrainingOptions trainerOptions = 783 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) 784 .setTrainingInterval( 785 new TrainingInterval.Builder() 786 .setSchedulingMode( 787 TrainingInterval.SCHEDULING_MODE_RECURRENT) 788 .setMinimumIntervalMillis(userDefinedIntervalMillis) 789 .build()) 790 .build(); 791 long nowMillis = 1000; 792 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 793 mJobManager.onTrainerStartCalled(trainerOptions, new TestFederatedComputeCallback()); 794 795 nowMillis = 2000; 796 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 797 mJobManager.onTrainingStarted(JOB_ID1); 798 799 nowMillis = 3000; 800 byte[] intervalOptions = 801 createTrainingIntervalOptions(SchedulingMode.RECURRENT, userDefinedIntervalMillis); 802 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 803 mJobManager.onTrainingCompleted( 804 JOB_ID1, 805 POPULATION_NAME1, 806 TrainingIntervalOptions.getRootAsTrainingIntervalOptions( 807 ByteBuffer.wrap(intervalOptions)), 808 new TaskRetry.Builder() 809 .setMinDelay(minRetryDelayMillis) 810 .setMaxDelay(maxRetryDelayMillis) 811 .build(), 812 TrainingResult.SUCCESS); 813 814 List<FederatedTrainingTask> taskList = 815 mTrainingTaskDao.getFederatedTrainingTask(null, null); 816 FederatedTrainingTask expectedTask = 817 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, 
intervalOptions) 818 .creationTime(1000L) 819 .lastScheduledTime(1000L) 820 .lastRunStartTime(2000L) // Match the time of calling onTrainingStarted() 821 .lastRunEndTime(3000L) // Match the time of calling onTrainingCompleted() 822 .schedulingReason( 823 SchedulingReason.SCHEDULING_REASON_FEDERATED_COMPUTATION_RETRY) 824 .earliestNextRunTime(3000 + userDefinedIntervalMillis) 825 .build(); 826 assertThat(taskList).containsExactly(expectedTask); 827 828 assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); 829 assertJobInfosMatch( 830 mJobScheduler.getPendingJob(JOB_ID1), 831 buildExpectedJobInfo(JOB_ID1, userDefinedIntervalMillis)); 832 } 833 834 @Test testRescheduleFLTask_recurrent_serverDefinedInterval()835 public void testRescheduleFLTask_recurrent_serverDefinedInterval() throws Exception { 836 // Define a server returned interval which is larger than the user defined interval 837 long serverDefinedIntervalMillis = 4000_000; 838 long userDefinedIntervalMillis = 3000_000; 839 840 TrainingOptions trainerOptions = 841 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) 842 .setTrainingInterval( 843 new TrainingInterval.Builder() 844 .setSchedulingMode( 845 TrainingInterval.SCHEDULING_MODE_RECURRENT) 846 .setMinimumIntervalMillis(userDefinedIntervalMillis) 847 .build()) 848 .build(); 849 850 long nowMillis = 1000; 851 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 852 mJobManager.onTrainerStartCalled(trainerOptions, new TestFederatedComputeCallback()); 853 854 nowMillis = 2000; 855 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 856 mJobManager.onTrainingStarted(JOB_ID1); 857 858 nowMillis = 3000; 859 byte[] intervalOptions = 860 createTrainingIntervalOptions(SchedulingMode.RECURRENT, userDefinedIntervalMillis); 861 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 862 mJobManager.onTrainingCompleted( 863 JOB_ID1, 864 POPULATION_NAME1, 865 TrainingIntervalOptions.getRootAsTrainingIntervalOptions( 866 ByteBuffer.wrap(intervalOptions)), 867 new 
TaskRetry.Builder() 868 .setMinDelay(serverDefinedIntervalMillis) 869 .setMaxDelay(serverDefinedIntervalMillis) 870 .build(), 871 TrainingResult.SUCCESS); 872 873 List<FederatedTrainingTask> taskList = 874 mTrainingTaskDao.getFederatedTrainingTask(null, null); 875 FederatedTrainingTask expectedTask = 876 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, intervalOptions) 877 .creationTime(1000L) 878 .lastScheduledTime(1000L) 879 .lastRunStartTime(2000L) // Match the time of calling onTrainingStarted() 880 .lastRunEndTime(3000L) // Match the time of calling onTrainingCompleted() 881 .schedulingReason( 882 SchedulingReason.SCHEDULING_REASON_FEDERATED_COMPUTATION_RETRY) 883 .earliestNextRunTime(3000 + serverDefinedIntervalMillis) 884 .build(); 885 assertThat(taskList).containsExactly(expectedTask); 886 887 assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); 888 assertJobInfosMatch( 889 mJobScheduler.getPendingJob(JOB_ID1), 890 buildExpectedJobInfo(JOB_ID1, serverDefinedIntervalMillis)); 891 } 892 893 @Test testRescheduleFLTask_recurrent_didnotContribute()894 public void testRescheduleFLTask_recurrent_didnotContribute() throws Exception { 895 // Define a server returned interval which is larger than the user defined interval 896 long serverDefinedIntervalMillis = 4000_000; 897 long userDefinedIntervalMillis = 3000_000; 898 899 TrainingOptions trainerOptions = 900 basicFLOptionsBuilder(JOB_ID1, POPULATION_NAME1) 901 .setTrainingInterval( 902 new TrainingInterval.Builder() 903 .setSchedulingMode( 904 TrainingInterval.SCHEDULING_MODE_RECURRENT) 905 .setMinimumIntervalMillis(userDefinedIntervalMillis) 906 .build()) 907 .build(); 908 909 long nowMillis = 1000; 910 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 911 mJobManager.onTrainerStartCalled(trainerOptions, new TestFederatedComputeCallback()); 912 913 nowMillis = 2000; 914 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 915 mJobManager.onTrainingStarted(JOB_ID1); 916 917 nowMillis = 3000; 918 
byte[] intervalOptions = 919 createTrainingIntervalOptions(SchedulingMode.RECURRENT, userDefinedIntervalMillis); 920 when(mClock.currentTimeMillis()).thenReturn(nowMillis); 921 mJobManager.onTrainingCompleted( 922 JOB_ID1, 923 POPULATION_NAME1, 924 TrainingIntervalOptions.getRootAsTrainingIntervalOptions( 925 ByteBuffer.wrap(intervalOptions)), 926 new TaskRetry.Builder() 927 .setMinDelay(serverDefinedIntervalMillis) 928 .setMaxDelay(serverDefinedIntervalMillis) 929 .build(), 930 TrainingResult.FAIL); 931 932 List<FederatedTrainingTask> taskList = 933 mTrainingTaskDao.getFederatedTrainingTask(null, null); 934 FederatedTrainingTask expectedTask = 935 basicFLTrainingTaskBuilder(JOB_ID1, POPULATION_NAME1, intervalOptions) 936 .creationTime(1000L) 937 .lastScheduledTime(1000L) 938 .lastRunStartTime(2000L) // Match the time of calling onTrainingStarted() 939 .lastRunEndTime(3000L) // Match the time of calling onTrainingCompleted() 940 .schedulingReason( 941 SchedulingReason.SCHEDULING_REASON_FEDERATED_COMPUTATION_RETRY) 942 .earliestNextRunTime(3000 + serverDefinedIntervalMillis) 943 .build(); 944 assertThat(taskList).containsExactly(expectedTask); 945 946 assertThat(mJobScheduler.getAllPendingJobs()).hasSize(1); 947 assertJobInfosMatch( 948 mJobScheduler.getPendingJob(JOB_ID1), 949 buildExpectedJobInfo(JOB_ID1, serverDefinedIntervalMillis)); 950 } 951 952 /** 953 * Helper for checking that two JobInfos match, since JobInfos unfortunately can't be compared 954 * directly. 955 */ assertJobInfosMatch(JobInfo pendingJob, JobInfo expectedJobInfo)956 public static void assertJobInfosMatch(JobInfo pendingJob, JobInfo expectedJobInfo) { 957 // Compare most of JobInfo's properties that may be set by our code. 
958 assertWithMessage("id").that(pendingJob.getId()).isEqualTo(expectedJobInfo.getId()); 959 assertWithMessage("service") 960 .that(pendingJob.getService()) 961 .isEqualTo(expectedJobInfo.getService()); 962 assertWithMessage("persisted") 963 .that(pendingJob.isPersisted()) 964 .isEqualTo(expectedJobInfo.isPersisted()); 965 assertWithMessage("networkType") 966 .that(pendingJob.getNetworkType()) 967 .isEqualTo(expectedJobInfo.getNetworkType()); 968 assertWithMessage("requireDeviceIdle") 969 .that(pendingJob.isRequireDeviceIdle()) 970 .isEqualTo(expectedJobInfo.isRequireDeviceIdle()); 971 assertWithMessage("requireCharging") 972 .that(pendingJob.isRequireCharging()) 973 .isEqualTo(expectedJobInfo.isRequireCharging()); 974 assertWithMessage("minLatencyMillis") 975 .that(pendingJob.getMinLatencyMillis()) 976 .isEqualTo(expectedJobInfo.getMinLatencyMillis()); 977 assertWithMessage("maxExecutionDelayMillis") 978 .that(pendingJob.getMaxExecutionDelayMillis()) 979 .isEqualTo(expectedJobInfo.getMaxExecutionDelayMillis()); 980 } 981 basicFLOptionsBuilder(int jobId, String population)982 private static TrainingOptions.Builder basicFLOptionsBuilder(int jobId, String population) { 983 return new TrainingOptions.Builder() 984 .setPopulationName(population) 985 .setJobSchedulerJobId(jobId); 986 } 987 buildExpectedJobInfo(int jobId, long minLatencyMillis)988 private JobInfo buildExpectedJobInfo(int jobId, long minLatencyMillis) { 989 JobInfo.Builder jobInfo = 990 new JobInfo.Builder( 991 jobId, 992 new ComponentName(mContext.getPackageName(), TRAINING_JOB_SERVICE)) 993 .setPersisted(true) 994 .setRequiresDeviceIdle(true) 995 // the latency should be capped. 
996 .setMinimumLatency(minLatencyMillis) 997 .setRequiresCharging(true); 998 jobInfo.setRequiredNetworkType(JobInfo.NETWORK_TYPE_UNMETERED); 999 1000 return jobInfo.build(); 1001 } 1002 basicFLTrainingTaskBuilder( int jobId, String population, @Nullable byte[] trainingIntervalOptions)1003 private FederatedTrainingTask.Builder basicFLTrainingTaskBuilder( 1004 int jobId, String population, @Nullable byte[] trainingIntervalOptions) { 1005 FederatedTrainingTask.Builder builder = 1006 FederatedTrainingTask.builder() 1007 .jobId(jobId) 1008 .populationName(population) 1009 .lastScheduledTime(0L) 1010 .lastRunStartTime(0L) 1011 .lastRunEndTime(0L) 1012 .constraints(DEFAULT_CONSTRAINTS) 1013 .appPackageName(mContext.getPackageName()); 1014 if (trainingIntervalOptions != null) { 1015 builder.intervalOptions(trainingIntervalOptions); 1016 } 1017 return builder; 1018 } 1019 createTrainingIntervalOptionsAsRoot( int schedulingMode, long intervalMillis)1020 private static TrainingIntervalOptions createTrainingIntervalOptionsAsRoot( 1021 int schedulingMode, long intervalMillis) { 1022 byte[] intervalOptions = createTrainingIntervalOptions(schedulingMode, intervalMillis); 1023 return TrainingIntervalOptions.getRootAsTrainingIntervalOptions( 1024 ByteBuffer.wrap(intervalOptions)); 1025 } 1026 createTrainingIntervalOptions(int schedulingMode, long intervalMillis)1027 private static byte[] createTrainingIntervalOptions(int schedulingMode, long intervalMillis) { 1028 FlatBufferBuilder builder = new FlatBufferBuilder(); 1029 builder.finish( 1030 TrainingIntervalOptions.createTrainingIntervalOptions( 1031 builder, schedulingMode, intervalMillis)); 1032 return builder.sizedByteArray(); 1033 } 1034 createDefaultTrainingConstraints()1035 private static byte[] createDefaultTrainingConstraints() { 1036 FlatBufferBuilder builder = new FlatBufferBuilder(); 1037 builder.finish(TrainingConstraints.createTrainingConstraints(builder, true, true, true)); 1038 return builder.sizedByteArray(); 1039 } 
1040 1041 class TestFederatedComputeCallback extends IFederatedComputeCallback.Stub { 1042 @Override onSuccess()1043 public void onSuccess() { 1044 mSuccess = true; 1045 mLatch.countDown(); 1046 } 1047 1048 @Override onFailure(int errorCode)1049 public void onFailure(int errorCode) { 1050 mLatch.countDown(); 1051 } 1052 } 1053 } 1054