1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.federatedcompute.services.scheduling; 18 19 import static android.federatedcompute.common.ClientConstants.STATUS_INTERNAL_ERROR; 20 21 import static com.android.federatedcompute.services.scheduling.SchedulingUtil.convertSchedulingMode; 22 23 import static java.util.concurrent.TimeUnit.SECONDS; 24 25 import android.annotation.NonNull; 26 import android.annotation.Nullable; 27 import android.content.Context; 28 import android.federatedcompute.aidl.IFederatedComputeCallback; 29 import android.federatedcompute.common.TrainingInterval; 30 import android.federatedcompute.common.TrainingOptions; 31 import android.os.RemoteException; 32 import android.util.Log; 33 34 import com.android.federatedcompute.services.common.Clock; 35 import com.android.federatedcompute.services.common.Flags; 36 import com.android.federatedcompute.services.common.MonotonicClock; 37 import com.android.federatedcompute.services.common.PhFlags; 38 import com.android.federatedcompute.services.common.TaskRetry; 39 import com.android.federatedcompute.services.common.TrainingResult; 40 import com.android.federatedcompute.services.data.FederatedTrainingTask; 41 import com.android.federatedcompute.services.data.FederatedTrainingTaskDao; 42 import com.android.federatedcompute.services.data.fbs.SchedulingMode; 43 import com.android.federatedcompute.services.data.fbs.SchedulingReason; 44 import com.android.federatedcompute.services.data.fbs.TrainingConstraints; 45 import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions; 46 47 import com.google.common.annotations.VisibleForTesting; 48 import com.google.flatbuffers.FlatBufferBuilder; 49 50 import java.util.Arrays; 51 import java.util.HashSet; 52 import java.util.Set; 53 54 /** Handles scheduling training tasks e.g. calling into JobScheduler, maintaining datastore. */ 55 public class FederatedComputeJobManager { 56 private static final String TAG = "FederatedComputeJobManager"; 57 58 @NonNull private final Context mContext; 59 private final FederatedTrainingTaskDao mFederatedTrainingTaskDao; 60 private final JobSchedulerHelper mJobSchedulerHelper; 61 private static FederatedComputeJobManager sSingletonInstance; 62 private Clock mClock; 63 private final Flags mFlags; 64 65 @VisibleForTesting FederatedComputeJobManager( @onNull Context context, FederatedTrainingTaskDao federatedTrainingTaskDao, JobSchedulerHelper jobSchedulerHelper, @NonNull Clock clock, Flags flag)66 FederatedComputeJobManager( 67 @NonNull Context context, 68 FederatedTrainingTaskDao federatedTrainingTaskDao, 69 JobSchedulerHelper jobSchedulerHelper, 70 @NonNull Clock clock, 71 Flags flag) { 72 this.mContext = context; 73 this.mFederatedTrainingTaskDao = federatedTrainingTaskDao; 74 this.mJobSchedulerHelper = jobSchedulerHelper; 75 this.mClock = clock; 76 this.mFlags = flag; 77 } 78 79 /** Returns an instance of FederatedComputeJobManager given a context. */ 80 @NonNull getInstance(@onNull Context mContext)81 public static FederatedComputeJobManager getInstance(@NonNull Context mContext) { 82 synchronized (FederatedComputeJobManager.class) { 83 if (sSingletonInstance == null) { 84 Clock clock = MonotonicClock.getInstance(); 85 sSingletonInstance = 86 new FederatedComputeJobManager( 87 mContext, 88 FederatedTrainingTaskDao.getInstance(mContext), 89 new JobSchedulerHelper(clock), 90 clock, 91 PhFlags.getInstance()); 92 } 93 return sSingletonInstance; 94 } 95 } 96 /** 97 * Called when a client indicates via the client API that a task with the given parameters 98 * should be scheduled. 99 */ onTrainerStartCalled( TrainingOptions trainingOptions, IFederatedComputeCallback callback)100 public synchronized void onTrainerStartCalled( 101 TrainingOptions trainingOptions, IFederatedComputeCallback callback) { 102 FederatedTrainingTask existingTask = 103 mFederatedTrainingTaskDao.findAndRemoveTaskByPopulationName( 104 trainingOptions.getPopulationName()); 105 Set<FederatedTrainingTask> trainingTasksToCancel = new HashSet<>(); 106 // If another task with same jobId exists, we only need to delete it and don't need cancel 107 // the task because we will overwrite it anyway. 108 mFederatedTrainingTaskDao.findAndRemoveTaskByJobId(trainingOptions.getJobSchedulerJobId()); 109 long nowMs = mClock.currentTimeMillis(); 110 int jobId = trainingOptions.getJobSchedulerJobId(); 111 boolean shouldSchedule = false; 112 FederatedTrainingTask newTask; 113 byte[] newTrainingConstraint = buildTrainingConstraints(); 114 115 if (existingTask == null) { 116 FederatedTrainingTask.Builder newTaskBuilder = 117 FederatedTrainingTask.builder() 118 .appPackageName(mContext.getPackageName()) 119 .jobId(jobId) 120 .creationTime(nowMs) 121 .lastScheduledTime(nowMs) 122 .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK) 123 .constraints(newTrainingConstraint) 124 .populationName(trainingOptions.getPopulationName()) 125 .earliestNextRunTime( 126 SchedulingUtil.getEarliestRuntimeForInitialSchedule( 127 nowMs, 0, trainingOptions, mFlags)); 128 if (trainingOptions.getTrainingInterval() != null) { 129 newTaskBuilder.intervalOptions( 130 buildTrainingIntervalOptions(trainingOptions.getTrainingInterval())); 131 } 132 newTask = newTaskBuilder.build(); 133 shouldSchedule = true; 134 } else { 135 // If a task does exist already then update only those fields that should be 136 // updated: job ID, population name, constraints, last scheduled time, BUT maintain 137 // other important fields like the earliest next run time unless the population or job 138 // ID has changed. This ensures that repeated calls to onTrainerStart do not keep 139 // postponing the job's next runtime. 140 FederatedTrainingTask.Builder newTaskBuilder = 141 existingTask.toBuilder() 142 .jobId(jobId) 143 .constraints(buildTrainingConstraints()) 144 .lastScheduledTime(nowMs); 145 if (detectKeyParametersChanged(trainingOptions, existingTask, trainingTasksToCancel)) { 146 newTaskBuilder.intervalOptions(null).lastRunStartTime(null).lastRunEndTime(null); 147 newTaskBuilder 148 .populationName(trainingOptions.getPopulationName()) 149 .earliestNextRunTime( 150 SchedulingUtil.getEarliestRuntimeForInitialSchedule( 151 nowMs, nowMs, trainingOptions, mFlags)); 152 if (trainingOptions.getTrainingInterval() != null) { 153 newTaskBuilder.intervalOptions( 154 buildTrainingIntervalOptions(trainingOptions.getTrainingInterval())); 155 } 156 shouldSchedule = true; 157 } else { 158 long earliestNextRunTime = 159 SchedulingUtil.getEarliestRuntimeForExistingTask( 160 existingTask, trainingOptions, mFlags, nowMs); 161 long maxExpectedRuntimeSecs = 162 mFlags.getTrainingServiceResultCallbackTimeoutSecs() + /*buffer*/ 30; 163 boolean currentlyRunningHeuristic = 164 existingTask.lastRunStartTime() < nowMs 165 && nowMs - existingTask.lastRunStartTime() 166 < 1000 * maxExpectedRuntimeSecs 167 && existingTask.lastRunStartTime() > existingTask.lastRunEndTime(); 168 shouldSchedule = 169 !currentlyRunningHeuristic 170 && (!mJobSchedulerHelper.isTaskScheduled(mContext, existingTask) 171 || !Arrays.equals( 172 existingTask.constraints(), newTrainingConstraint) 173 || !existingTask 174 .earliestNextRunTime() 175 .equals(earliestNextRunTime)); 176 177 // If we have to reschedule, update the earliest next run time. Otherwise, 178 // retain the original earliest next run time. 179 newTaskBuilder.earliestNextRunTime( 180 shouldSchedule ? earliestNextRunTime : existingTask.earliestNextRunTime()); 181 } 182 // If we have to reschedule, mark this task as "new"; otherwise, retain the original 183 // reason for scheduling it. 184 newTaskBuilder.schedulingReason( 185 shouldSchedule 186 ? SchedulingReason.SCHEDULING_REASON_NEW_TASK 187 : existingTask.schedulingReason()); 188 newTask = newTaskBuilder.build(); 189 } 190 191 // Now reconcile the new task store and JobScheduler. 192 // 193 // First, if necessary, try to (re)schedule the task. 194 if (shouldSchedule) { 195 boolean scheduleResult = mJobSchedulerHelper.scheduleTask(mContext, newTask); 196 if (!scheduleResult) { 197 Log.w( 198 TAG, 199 "JobScheduler returned failure when starting training job " 200 + newTask.jobId()); 201 // If scheduling failed then leave the task store as-is, and bail. 202 sendError(callback); 203 return; 204 } 205 } 206 207 // Add the new task into federated training task store. if failed, return the error. 208 boolean storeResult = 209 mFederatedTrainingTaskDao.updateOrInsertFederatedTrainingTask(newTask); 210 if (!storeResult) { 211 Log.w( 212 TAG, 213 "JobScheduler returned failure when storing training job!" + newTask.jobId()); 214 sendError(callback); 215 return; 216 } 217 // Second, if the task previously had a different job ID or a if there was another 218 // task with the same population name, then cancel the corresponding old tasks. 219 for (FederatedTrainingTask trainingTaskToCancel : trainingTasksToCancel) { 220 Log.i(TAG, " JobScheduler cancel the task " + newTask.jobId()); 221 mJobSchedulerHelper.cancelTask(mContext, trainingTaskToCancel); 222 } 223 sendSuccess(callback); 224 } 225 226 /** Called when a training task identified by {@code jobId} starts running. */ 227 @Nullable 228 public synchronized FederatedTrainingTask onTrainingStarted(int jobId) { 229 FederatedTrainingTask existingTask = 230 mFederatedTrainingTaskDao.findAndRemoveTaskByJobId(jobId); 231 if (existingTask == null) { 232 return null; 233 } 234 long ttlMs = SECONDS.toMillis(mFlags.getTrainingTimeForLiveSeconds()); 235 long nowMs = mClock.currentTimeMillis(); 236 if (ttlMs > 0 && nowMs - existingTask.lastScheduledTime() > ttlMs) { 237 // If the TTL is expired, then delete the task. 238 Log.i(TAG, String.format("Training task %d TTLd", jobId)); 239 return null; 240 } 241 FederatedTrainingTask newTask = existingTask.toBuilder().lastRunStartTime(nowMs).build(); 242 mFederatedTrainingTaskDao.updateOrInsertFederatedTrainingTask(newTask); 243 return newTask; 244 } 245 246 /** Called when a training task completed. */ onTrainingCompleted( int jobId, String populationName, TrainingIntervalOptions trainingIntervalOptions, TaskRetry taskRetry, @TrainingResult int trainingResult)247 public synchronized void onTrainingCompleted( 248 int jobId, 249 String populationName, 250 TrainingIntervalOptions trainingIntervalOptions, 251 TaskRetry taskRetry, 252 @TrainingResult int trainingResult) { 253 boolean result = 254 rescheduleFederatedTaskAfterTraining( 255 jobId, populationName, trainingIntervalOptions, taskRetry, trainingResult); 256 if (!result) { 257 Log.e(TAG, "JobScheduler returned failure after successful run!"); 258 } 259 } 260 261 /** Tries to reschedule a federated task after a failed or successful training run. */ rescheduleFederatedTaskAfterTraining( int jobId, String populationName, TrainingIntervalOptions intervalOptions, TaskRetry taskRetry, @TrainingResult int trainingResult)262 private synchronized boolean rescheduleFederatedTaskAfterTraining( 263 int jobId, 264 String populationName, 265 TrainingIntervalOptions intervalOptions, 266 TaskRetry taskRetry, 267 @TrainingResult int trainingResult) { 268 FederatedTrainingTask existingTask = 269 mFederatedTrainingTaskDao.findAndRemoveTaskByPopulationAndJobId( 270 populationName, jobId); 271 // If task was deleted already, then return early, but still consider it a success 272 // since this is not really an error case (e.g. Trainer.stop may have simply been 273 // called while training was running). 274 if (existingTask == null) { 275 return true; 276 } 277 boolean hasContributed = trainingResult == TrainingResult.SUCCESS; 278 if (intervalOptions != null 279 && intervalOptions.schedulingMode() == SchedulingMode.ONE_TIME 280 && hasContributed) { 281 mJobSchedulerHelper.cancelTask(mContext, existingTask); 282 Log.i(TAG, "federated task remove because oneoff task succeeded: " + jobId); 283 return true; 284 } 285 // Update the task and add it back to the training task store. 286 long nowMillis = mClock.currentTimeMillis(); 287 long earliestNextRunTime = 288 SchedulingUtil.getEarliestRuntimeForFCReschedule( 289 nowMillis, intervalOptions, taskRetry, hasContributed, mFlags); 290 FederatedTrainingTask.Builder newTaskBuilder = 291 existingTask.toBuilder() 292 .lastRunEndTime(nowMillis) 293 .earliestNextRunTime(earliestNextRunTime); 294 newTaskBuilder.schedulingReason( 295 taskRetry != null 296 ? SchedulingReason.SCHEDULING_REASON_FEDERATED_COMPUTATION_RETRY 297 : SchedulingReason.SCHEDULING_REASON_FAILURE); 298 FederatedTrainingTask newTask = newTaskBuilder.build(); 299 mFederatedTrainingTaskDao.updateOrInsertFederatedTrainingTask(newTask); 300 return mJobSchedulerHelper.scheduleTask(mContext, newTask); 301 } 302 buildTrainingConstraints()303 private static byte[] buildTrainingConstraints() { 304 FlatBufferBuilder builder = new FlatBufferBuilder(); 305 builder.finish(TrainingConstraints.createTrainingConstraints(builder, true, true, true)); 306 return builder.sizedByteArray(); 307 } 308 buildTrainingIntervalOptions( @ullable TrainingInterval trainingInterval)309 private static byte[] buildTrainingIntervalOptions( 310 @Nullable TrainingInterval trainingInterval) { 311 FlatBufferBuilder builder = new FlatBufferBuilder(); 312 if (trainingInterval == null) { 313 builder.finish( 314 TrainingIntervalOptions.createTrainingIntervalOptions( 315 builder, SchedulingMode.ONE_TIME, 0)); 316 return builder.sizedByteArray(); 317 } 318 builder.finish( 319 TrainingIntervalOptions.createTrainingIntervalOptions( 320 builder, 321 convertSchedulingMode(trainingInterval.getSchedulingMode()), 322 trainingInterval.getMinimumIntervalMillis())); 323 324 return builder.sizedByteArray(); 325 } 326 detectKeyParametersChanged( TrainingOptions newTaskOptions, FederatedTrainingTask existingTask, Set<FederatedTrainingTask> trainingTasksToCancel)327 private boolean detectKeyParametersChanged( 328 TrainingOptions newTaskOptions, 329 FederatedTrainingTask existingTask, 330 Set<FederatedTrainingTask> trainingTasksToCancel) { 331 // Check if the task previously had a different JobScheduler job ID. If it did then 332 // cancel that job for that old ID so it's not left hanging. 333 boolean jobIdChanged = existingTask.jobId() != newTaskOptions.getJobSchedulerJobId(); 334 if (jobIdChanged) { 335 Log.i( 336 TAG, 337 String.format( 338 "JobScheduler job id changed from %d to %d", 339 existingTask.jobId(), newTaskOptions.getJobSchedulerJobId())); 340 trainingTasksToCancel.add(existingTask); 341 } 342 343 // Check if the task previously had a different population name. 344 boolean populationChanged = 345 !existingTask.populationName().equals(newTaskOptions.getPopulationName()); 346 if (populationChanged) { 347 Log.i( 348 TAG, 349 String.format( 350 "JobScheduler population name changed from %s to %s", 351 existingTask.populationName(), newTaskOptions.getPopulationName())); 352 } 353 354 boolean trainingIntervalChanged = trainingIntervalChanged(newTaskOptions, existingTask); 355 if (trainingIntervalChanged) { 356 Log.i( 357 TAG, 358 String.format( 359 "JobScheduler training interval changed from %s to %s", 360 existingTask.getTrainingIntervalOptions(), 361 newTaskOptions.getTrainingInterval())); 362 } 363 return jobIdChanged || populationChanged || trainingIntervalChanged; 364 } 365 trainingIntervalChanged( TrainingOptions newTaskOptions, FederatedTrainingTask existingTask)366 private static boolean trainingIntervalChanged( 367 TrainingOptions newTaskOptions, FederatedTrainingTask existingTask) { 368 byte[] incomingTrainingIntervalOptions = 369 newTaskOptions.getTrainingInterval() == null 370 ? null 371 : buildTrainingIntervalOptions(newTaskOptions.getTrainingInterval()); 372 return !Arrays.equals(incomingTrainingIntervalOptions, existingTask.intervalOptions()); 373 } 374 sendError(@onNull IFederatedComputeCallback callback)375 private void sendError(@NonNull IFederatedComputeCallback callback) { 376 try { 377 callback.onFailure(STATUS_INTERNAL_ERROR); 378 } catch (RemoteException e) { 379 Log.e(TAG, "IFederatedComputeCallback error", e); 380 } 381 } 382 sendSuccess(@onNull IFederatedComputeCallback callback)383 private void sendSuccess(@NonNull IFederatedComputeCallback callback) { 384 try { 385 callback.onSuccess(); 386 } catch (RemoteException e) { 387 Log.e(TAG, "IFederatedComputeCallback error", e); 388 } 389 } 390 } 391