• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package com.android.federatedcompute.services.scheduling;
18 
19 import static android.federatedcompute.common.ClientConstants.STATUS_INTERNAL_ERROR;
20 
21 import static com.android.federatedcompute.services.scheduling.SchedulingUtil.convertSchedulingMode;
22 
23 import static java.util.concurrent.TimeUnit.SECONDS;
24 
25 import android.annotation.NonNull;
26 import android.annotation.Nullable;
27 import android.content.Context;
28 import android.federatedcompute.aidl.IFederatedComputeCallback;
29 import android.federatedcompute.common.TrainingInterval;
30 import android.federatedcompute.common.TrainingOptions;
31 import android.os.RemoteException;
32 import android.util.Log;
33 
34 import com.android.federatedcompute.services.common.Clock;
35 import com.android.federatedcompute.services.common.Flags;
36 import com.android.federatedcompute.services.common.MonotonicClock;
37 import com.android.federatedcompute.services.common.PhFlags;
38 import com.android.federatedcompute.services.common.TaskRetry;
39 import com.android.federatedcompute.services.common.TrainingResult;
40 import com.android.federatedcompute.services.data.FederatedTrainingTask;
41 import com.android.federatedcompute.services.data.FederatedTrainingTaskDao;
42 import com.android.federatedcompute.services.data.fbs.SchedulingMode;
43 import com.android.federatedcompute.services.data.fbs.SchedulingReason;
44 import com.android.federatedcompute.services.data.fbs.TrainingConstraints;
45 import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions;
46 
47 import com.google.common.annotations.VisibleForTesting;
48 import com.google.flatbuffers.FlatBufferBuilder;
49 
50 import java.util.Arrays;
51 import java.util.HashSet;
52 import java.util.Set;
53 
54 /** Handles scheduling training tasks e.g. calling into JobScheduler, maintaining datastore. */
55 public class FederatedComputeJobManager {
56     private static final String TAG = "FederatedComputeJobManager";
57 
58     @NonNull private final Context mContext;
59     private final FederatedTrainingTaskDao mFederatedTrainingTaskDao;
60     private final JobSchedulerHelper mJobSchedulerHelper;
61     private static FederatedComputeJobManager sSingletonInstance;
62     private Clock mClock;
63     private final Flags mFlags;
64 
65     @VisibleForTesting
FederatedComputeJobManager( @onNull Context context, FederatedTrainingTaskDao federatedTrainingTaskDao, JobSchedulerHelper jobSchedulerHelper, @NonNull Clock clock, Flags flag)66     FederatedComputeJobManager(
67             @NonNull Context context,
68             FederatedTrainingTaskDao federatedTrainingTaskDao,
69             JobSchedulerHelper jobSchedulerHelper,
70             @NonNull Clock clock,
71             Flags flag) {
72         this.mContext = context;
73         this.mFederatedTrainingTaskDao = federatedTrainingTaskDao;
74         this.mJobSchedulerHelper = jobSchedulerHelper;
75         this.mClock = clock;
76         this.mFlags = flag;
77     }
78 
79     /** Returns an instance of FederatedComputeJobManager given a context. */
80     @NonNull
getInstance(@onNull Context mContext)81     public static FederatedComputeJobManager getInstance(@NonNull Context mContext) {
82         synchronized (FederatedComputeJobManager.class) {
83             if (sSingletonInstance == null) {
84                 Clock clock = MonotonicClock.getInstance();
85                 sSingletonInstance =
86                         new FederatedComputeJobManager(
87                                 mContext,
88                                 FederatedTrainingTaskDao.getInstance(mContext),
89                                 new JobSchedulerHelper(clock),
90                                 clock,
91                                 PhFlags.getInstance());
92             }
93             return sSingletonInstance;
94         }
95     }
96     /**
97      * Called when a client indicates via the client API that a task with the given parameters
98      * should be scheduled.
99      */
onTrainerStartCalled( TrainingOptions trainingOptions, IFederatedComputeCallback callback)100     public synchronized void onTrainerStartCalled(
101             TrainingOptions trainingOptions, IFederatedComputeCallback callback) {
102         FederatedTrainingTask existingTask =
103                 mFederatedTrainingTaskDao.findAndRemoveTaskByPopulationName(
104                         trainingOptions.getPopulationName());
105         Set<FederatedTrainingTask> trainingTasksToCancel = new HashSet<>();
106         // If another task with same jobId exists, we only need to delete it and don't need cancel
107         // the task because we will overwrite it anyway.
108         mFederatedTrainingTaskDao.findAndRemoveTaskByJobId(trainingOptions.getJobSchedulerJobId());
109         long nowMs = mClock.currentTimeMillis();
110         int jobId = trainingOptions.getJobSchedulerJobId();
111         boolean shouldSchedule = false;
112         FederatedTrainingTask newTask;
113         byte[] newTrainingConstraint = buildTrainingConstraints();
114 
115         if (existingTask == null) {
116             FederatedTrainingTask.Builder newTaskBuilder =
117                     FederatedTrainingTask.builder()
118                             .appPackageName(mContext.getPackageName())
119                             .jobId(jobId)
120                             .creationTime(nowMs)
121                             .lastScheduledTime(nowMs)
122                             .schedulingReason(SchedulingReason.SCHEDULING_REASON_NEW_TASK)
123                             .constraints(newTrainingConstraint)
124                             .populationName(trainingOptions.getPopulationName())
125                             .earliestNextRunTime(
126                                     SchedulingUtil.getEarliestRuntimeForInitialSchedule(
127                                             nowMs, 0, trainingOptions, mFlags));
128             if (trainingOptions.getTrainingInterval() != null) {
129                 newTaskBuilder.intervalOptions(
130                         buildTrainingIntervalOptions(trainingOptions.getTrainingInterval()));
131             }
132             newTask = newTaskBuilder.build();
133             shouldSchedule = true;
134         } else {
135             // If a task does exist already then update only those fields that should be
136             // updated: job ID, population name, constraints, last scheduled time, BUT maintain
137             // other important fields like the earliest next run time unless the population or job
138             // ID has changed. This ensures that repeated calls to onTrainerStart do not keep
139             // postponing the job's next runtime.
140             FederatedTrainingTask.Builder newTaskBuilder =
141                     existingTask.toBuilder()
142                             .jobId(jobId)
143                             .constraints(buildTrainingConstraints())
144                             .lastScheduledTime(nowMs);
145             if (detectKeyParametersChanged(trainingOptions, existingTask, trainingTasksToCancel)) {
146                 newTaskBuilder.intervalOptions(null).lastRunStartTime(null).lastRunEndTime(null);
147                 newTaskBuilder
148                         .populationName(trainingOptions.getPopulationName())
149                         .earliestNextRunTime(
150                                 SchedulingUtil.getEarliestRuntimeForInitialSchedule(
151                                         nowMs, nowMs, trainingOptions, mFlags));
152                 if (trainingOptions.getTrainingInterval() != null) {
153                     newTaskBuilder.intervalOptions(
154                             buildTrainingIntervalOptions(trainingOptions.getTrainingInterval()));
155                 }
156                 shouldSchedule = true;
157             } else {
158                 long earliestNextRunTime =
159                         SchedulingUtil.getEarliestRuntimeForExistingTask(
160                                 existingTask, trainingOptions, mFlags, nowMs);
161                 long maxExpectedRuntimeSecs =
162                         mFlags.getTrainingServiceResultCallbackTimeoutSecs() + /*buffer*/ 30;
163                 boolean currentlyRunningHeuristic =
164                         existingTask.lastRunStartTime() < nowMs
165                                 && nowMs - existingTask.lastRunStartTime()
166                                         < 1000 * maxExpectedRuntimeSecs
167                                 && existingTask.lastRunStartTime() > existingTask.lastRunEndTime();
168                 shouldSchedule =
169                         !currentlyRunningHeuristic
170                                 && (!mJobSchedulerHelper.isTaskScheduled(mContext, existingTask)
171                                         || !Arrays.equals(
172                                                 existingTask.constraints(), newTrainingConstraint)
173                                         || !existingTask
174                                                 .earliestNextRunTime()
175                                                 .equals(earliestNextRunTime));
176 
177                 // If we have to reschedule, update the earliest next run time. Otherwise,
178                 // retain the original earliest next run time.
179                 newTaskBuilder.earliestNextRunTime(
180                         shouldSchedule ? earliestNextRunTime : existingTask.earliestNextRunTime());
181             }
182             // If we have to reschedule, mark this task as "new"; otherwise, retain the original
183             // reason for scheduling it.
184             newTaskBuilder.schedulingReason(
185                     shouldSchedule
186                             ? SchedulingReason.SCHEDULING_REASON_NEW_TASK
187                             : existingTask.schedulingReason());
188             newTask = newTaskBuilder.build();
189         }
190 
191         // Now reconcile the new task store and JobScheduler.
192         //
193         // First, if necessary, try to (re)schedule the task.
194         if (shouldSchedule) {
195             boolean scheduleResult = mJobSchedulerHelper.scheduleTask(mContext, newTask);
196             if (!scheduleResult) {
197                 Log.w(
198                         TAG,
199                         "JobScheduler returned failure when starting training job "
200                                 + newTask.jobId());
201                 // If scheduling failed then leave the task store as-is, and bail.
202                 sendError(callback);
203                 return;
204             }
205         }
206 
207         // Add the new task into federated training task store. if failed, return the error.
208         boolean storeResult =
209                 mFederatedTrainingTaskDao.updateOrInsertFederatedTrainingTask(newTask);
210         if (!storeResult) {
211             Log.w(
212                     TAG,
213                     "JobScheduler returned failure when storing training job!" + newTask.jobId());
214             sendError(callback);
215             return;
216         }
217         // Second, if the task previously had a different job ID or a if there was another
218         // task with the same population name, then cancel the corresponding old tasks.
219         for (FederatedTrainingTask trainingTaskToCancel : trainingTasksToCancel) {
220             Log.i(TAG, " JobScheduler cancel the task " + newTask.jobId());
221             mJobSchedulerHelper.cancelTask(mContext, trainingTaskToCancel);
222         }
223         sendSuccess(callback);
224     }
225 
226     /** Called when a training task identified by {@code jobId} starts running. */
227     @Nullable
228     public synchronized FederatedTrainingTask onTrainingStarted(int jobId) {
229         FederatedTrainingTask existingTask =
230                 mFederatedTrainingTaskDao.findAndRemoveTaskByJobId(jobId);
231         if (existingTask == null) {
232             return null;
233         }
234         long ttlMs = SECONDS.toMillis(mFlags.getTrainingTimeForLiveSeconds());
235         long nowMs = mClock.currentTimeMillis();
236         if (ttlMs > 0 && nowMs - existingTask.lastScheduledTime() > ttlMs) {
237             // If the TTL is expired, then delete the task.
238             Log.i(TAG, String.format("Training task %d TTLd", jobId));
239             return null;
240         }
241         FederatedTrainingTask newTask = existingTask.toBuilder().lastRunStartTime(nowMs).build();
242         mFederatedTrainingTaskDao.updateOrInsertFederatedTrainingTask(newTask);
243         return newTask;
244     }
245 
246     /** Called when a training task completed. */
onTrainingCompleted( int jobId, String populationName, TrainingIntervalOptions trainingIntervalOptions, TaskRetry taskRetry, @TrainingResult int trainingResult)247     public synchronized void onTrainingCompleted(
248             int jobId,
249             String populationName,
250             TrainingIntervalOptions trainingIntervalOptions,
251             TaskRetry taskRetry,
252             @TrainingResult int trainingResult) {
253         boolean result =
254                 rescheduleFederatedTaskAfterTraining(
255                         jobId, populationName, trainingIntervalOptions, taskRetry, trainingResult);
256         if (!result) {
257             Log.e(TAG, "JobScheduler returned failure after successful run!");
258         }
259     }
260 
261     /** Tries to reschedule a federated task after a failed or successful training run. */
rescheduleFederatedTaskAfterTraining( int jobId, String populationName, TrainingIntervalOptions intervalOptions, TaskRetry taskRetry, @TrainingResult int trainingResult)262     private synchronized boolean rescheduleFederatedTaskAfterTraining(
263             int jobId,
264             String populationName,
265             TrainingIntervalOptions intervalOptions,
266             TaskRetry taskRetry,
267             @TrainingResult int trainingResult) {
268         FederatedTrainingTask existingTask =
269                 mFederatedTrainingTaskDao.findAndRemoveTaskByPopulationAndJobId(
270                         populationName, jobId);
271         // If task was deleted already, then return early, but still consider it a success
272         // since this is not really an error case (e.g. Trainer.stop may have simply been
273         // called while training was running).
274         if (existingTask == null) {
275             return true;
276         }
277         boolean hasContributed = trainingResult == TrainingResult.SUCCESS;
278         if (intervalOptions != null
279                 && intervalOptions.schedulingMode() == SchedulingMode.ONE_TIME
280                 && hasContributed) {
281             mJobSchedulerHelper.cancelTask(mContext, existingTask);
282             Log.i(TAG, "federated task remove because oneoff task succeeded: " + jobId);
283             return true;
284         }
285         // Update the task and add it back to the training task store.
286         long nowMillis = mClock.currentTimeMillis();
287         long earliestNextRunTime =
288                 SchedulingUtil.getEarliestRuntimeForFCReschedule(
289                         nowMillis, intervalOptions, taskRetry, hasContributed, mFlags);
290         FederatedTrainingTask.Builder newTaskBuilder =
291                 existingTask.toBuilder()
292                         .lastRunEndTime(nowMillis)
293                         .earliestNextRunTime(earliestNextRunTime);
294         newTaskBuilder.schedulingReason(
295                 taskRetry != null
296                         ? SchedulingReason.SCHEDULING_REASON_FEDERATED_COMPUTATION_RETRY
297                         : SchedulingReason.SCHEDULING_REASON_FAILURE);
298         FederatedTrainingTask newTask = newTaskBuilder.build();
299         mFederatedTrainingTaskDao.updateOrInsertFederatedTrainingTask(newTask);
300         return mJobSchedulerHelper.scheduleTask(mContext, newTask);
301     }
302 
buildTrainingConstraints()303     private static byte[] buildTrainingConstraints() {
304         FlatBufferBuilder builder = new FlatBufferBuilder();
305         builder.finish(TrainingConstraints.createTrainingConstraints(builder, true, true, true));
306         return builder.sizedByteArray();
307     }
308 
buildTrainingIntervalOptions( @ullable TrainingInterval trainingInterval)309     private static byte[] buildTrainingIntervalOptions(
310             @Nullable TrainingInterval trainingInterval) {
311         FlatBufferBuilder builder = new FlatBufferBuilder();
312         if (trainingInterval == null) {
313             builder.finish(
314                     TrainingIntervalOptions.createTrainingIntervalOptions(
315                             builder, SchedulingMode.ONE_TIME, 0));
316             return builder.sizedByteArray();
317         }
318         builder.finish(
319                 TrainingIntervalOptions.createTrainingIntervalOptions(
320                         builder,
321                         convertSchedulingMode(trainingInterval.getSchedulingMode()),
322                         trainingInterval.getMinimumIntervalMillis()));
323 
324         return builder.sizedByteArray();
325     }
326 
detectKeyParametersChanged( TrainingOptions newTaskOptions, FederatedTrainingTask existingTask, Set<FederatedTrainingTask> trainingTasksToCancel)327     private boolean detectKeyParametersChanged(
328             TrainingOptions newTaskOptions,
329             FederatedTrainingTask existingTask,
330             Set<FederatedTrainingTask> trainingTasksToCancel) {
331         // Check if the task previously had a different JobScheduler job ID. If it did then
332         // cancel that job for that old ID so it's not left hanging.
333         boolean jobIdChanged = existingTask.jobId() != newTaskOptions.getJobSchedulerJobId();
334         if (jobIdChanged) {
335             Log.i(
336                     TAG,
337                     String.format(
338                             "JobScheduler job id changed from %d to %d",
339                             existingTask.jobId(), newTaskOptions.getJobSchedulerJobId()));
340             trainingTasksToCancel.add(existingTask);
341         }
342 
343         // Check if the task previously had a different population name.
344         boolean populationChanged =
345                 !existingTask.populationName().equals(newTaskOptions.getPopulationName());
346         if (populationChanged) {
347             Log.i(
348                     TAG,
349                     String.format(
350                             "JobScheduler population name changed from %s to %s",
351                             existingTask.populationName(), newTaskOptions.getPopulationName()));
352         }
353 
354         boolean trainingIntervalChanged = trainingIntervalChanged(newTaskOptions, existingTask);
355         if (trainingIntervalChanged) {
356             Log.i(
357                     TAG,
358                     String.format(
359                             "JobScheduler training interval changed from %s to %s",
360                             existingTask.getTrainingIntervalOptions(),
361                             newTaskOptions.getTrainingInterval()));
362         }
363         return jobIdChanged || populationChanged || trainingIntervalChanged;
364     }
365 
trainingIntervalChanged( TrainingOptions newTaskOptions, FederatedTrainingTask existingTask)366     private static boolean trainingIntervalChanged(
367             TrainingOptions newTaskOptions, FederatedTrainingTask existingTask) {
368         byte[] incomingTrainingIntervalOptions =
369                 newTaskOptions.getTrainingInterval() == null
370                         ? null
371                         : buildTrainingIntervalOptions(newTaskOptions.getTrainingInterval());
372         return !Arrays.equals(incomingTrainingIntervalOptions, existingTask.intervalOptions());
373     }
374 
sendError(@onNull IFederatedComputeCallback callback)375     private void sendError(@NonNull IFederatedComputeCallback callback) {
376         try {
377             callback.onFailure(STATUS_INTERNAL_ERROR);
378         } catch (RemoteException e) {
379             Log.e(TAG, "IFederatedComputeCallback error", e);
380         }
381     }
382 
sendSuccess(@onNull IFederatedComputeCallback callback)383     private void sendSuccess(@NonNull IFederatedComputeCallback callback) {
384         try {
385             callback.onSuccess();
386         } catch (RemoteException e) {
387             Log.e(TAG, "IFederatedComputeCallback error", e);
388         }
389     }
390 }
391