• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package com.android.federatedcompute.services.training;
18 
19 import android.annotation.NonNull;
20 import android.annotation.Nullable;
21 import android.content.Context;
22 import android.util.Log;
23 
24 import com.android.federatedcompute.services.common.Flags;
25 import com.android.federatedcompute.services.common.PhFlags;
26 import com.android.federatedcompute.services.common.TaskRetry;
27 import com.android.federatedcompute.services.common.TrainingResult;
28 import com.android.federatedcompute.services.data.FederatedTrainingTask;
29 import com.android.federatedcompute.services.scheduling.FederatedComputeJobManager;
30 import com.android.federatedcompute.services.scheduling.SchedulingUtil;
31 
32 import com.google.common.annotations.VisibleForTesting;
33 
34 import javax.annotation.concurrent.GuardedBy;
35 
36 /** The worker to execute federated computation jobs. */
37 public class FederatedComputeWorker {
38     private static final String TAG = "FederatedComputeWorker";
39 
40     static final Object LOCK = new Object();
41 
42     @GuardedBy("LOCK")
43     @Nullable
44     private TrainingRun mActiveRun = null;
45 
46     @Nullable private final FederatedComputeJobManager mJobManager;
47     private static volatile FederatedComputeWorker sFederatedComputeWorker;
48     private final Flags mFlags;
49 
50     @VisibleForTesting
FederatedComputeWorker(FederatedComputeJobManager jobManager, Flags flags)51     public FederatedComputeWorker(FederatedComputeJobManager jobManager, Flags flags) {
52         this.mJobManager = jobManager;
53         this.mFlags = flags;
54     }
55 
56     /** Gets an instance of {@link FederatedComputeWorker}. */
57     @NonNull
getInstance(Context context)58     public static FederatedComputeWorker getInstance(Context context) {
59         synchronized (FederatedComputeWorker.class) {
60             if (sFederatedComputeWorker == null) {
61                 sFederatedComputeWorker =
62                         new FederatedComputeWorker(
63                                 FederatedComputeJobManager.getInstance(context),
64                                 PhFlags.getInstance());
65             }
66             return sFederatedComputeWorker;
67         }
68     }
69 
70     /** Starts a training run with the given job Id. */
startTrainingRun(int jobId)71     public boolean startTrainingRun(int jobId) {
72         Log.d(TAG, "startTrainingRun()");
73         FederatedTrainingTask trainingTask = mJobManager.onTrainingStarted(jobId);
74         if (trainingTask == null) {
75             Log.i(TAG, String.format("Could not find task to run for job ID %s", jobId));
76             return false;
77         }
78 
79         synchronized (LOCK) {
80             // Only allow one concurrent federated computation job.
81             if (mActiveRun != null) {
82                 Log.i(
83                         TAG,
84                         String.format(
85                                 "Delaying %d/%s another run is already active!",
86                                 jobId, trainingTask.populationName()));
87                 mJobManager.onTrainingCompleted(
88                         jobId,
89                         trainingTask.populationName(),
90                         trainingTask.getTrainingIntervalOptions(),
91                         /* taskRetry= */ null,
92                         TrainingResult.FAIL);
93                 return false;
94             }
95             TrainingRun run = new TrainingRun(jobId, trainingTask);
96             this.mActiveRun = run;
97             doTraining(run);
98             // TODO: get retry info from federated server.
99             TaskRetry taskRetry = SchedulingUtil.generateTransientErrorTaskRetry(mFlags);
100             finish(this.mActiveRun, taskRetry, TrainingResult.SUCCESS);
101         }
102         return true;
103     }
104 
105     /** Cancels the running job if present. */
cancelActiveRun()106     public void cancelActiveRun() {
107         Log.d(TAG, "cancelActiveRun()");
108         synchronized (LOCK) {
109             if (mActiveRun == null) {
110                 return;
111             }
112             finish(mActiveRun, /* taskRetry= */ null, TrainingResult.FAIL);
113         }
114     }
115 
finish( TrainingRun runToFinish, TaskRetry taskRetry, @TrainingResult int trainingResult)116     private void finish(
117             TrainingRun runToFinish, TaskRetry taskRetry, @TrainingResult int trainingResult) {
118         synchronized (LOCK) {
119             if (mActiveRun != runToFinish) {
120                 return;
121             }
122             mActiveRun = null;
123             mJobManager.onTrainingCompleted(
124                     runToFinish.mJobId,
125                     runToFinish.mTask.populationName(),
126                     runToFinish.mTask.getTrainingIntervalOptions(),
127                     taskRetry,
128                     trainingResult);
129         }
130     }
131 
doTraining(TrainingRun run)132     private void doTraining(TrainingRun run) {
133         // TODO: add training logic.
134         Log.d(TAG, "Start run training job " + run.mJobId);
135     }
136 
137     private static final class TrainingRun {
138         private final int mJobId;
139         private final FederatedTrainingTask mTask;
140 
TrainingRun(int jobId, FederatedTrainingTask task)141         private TrainingRun(int jobId, FederatedTrainingTask task) {
142             this.mJobId = jobId;
143             this.mTask = task;
144         }
145     }
146 }
147