1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.federatedcompute.services.training; 18 19 import android.annotation.NonNull; 20 import android.annotation.Nullable; 21 import android.content.Context; 22 import android.util.Log; 23 24 import com.android.federatedcompute.services.common.Flags; 25 import com.android.federatedcompute.services.common.PhFlags; 26 import com.android.federatedcompute.services.common.TaskRetry; 27 import com.android.federatedcompute.services.common.TrainingResult; 28 import com.android.federatedcompute.services.data.FederatedTrainingTask; 29 import com.android.federatedcompute.services.scheduling.FederatedComputeJobManager; 30 import com.android.federatedcompute.services.scheduling.SchedulingUtil; 31 32 import com.google.common.annotations.VisibleForTesting; 33 34 import javax.annotation.concurrent.GuardedBy; 35 36 /** The worker to execute federated computation jobs. */ 37 public class FederatedComputeWorker { 38 private static final String TAG = "FederatedComputeWorker"; 39 40 static final Object LOCK = new Object(); 41 42 @GuardedBy("LOCK") 43 @Nullable 44 private TrainingRun mActiveRun = null; 45 46 @Nullable private final FederatedComputeJobManager mJobManager; 47 private static volatile FederatedComputeWorker sFederatedComputeWorker; 48 private final Flags mFlags; 49 50 @VisibleForTesting FederatedComputeWorker(FederatedComputeJobManager jobManager, Flags flags)51 public FederatedComputeWorker(FederatedComputeJobManager jobManager, Flags flags) { 52 this.mJobManager = jobManager; 53 this.mFlags = flags; 54 } 55 56 /** Gets an instance of {@link FederatedComputeWorker}. */ 57 @NonNull getInstance(Context context)58 public static FederatedComputeWorker getInstance(Context context) { 59 synchronized (FederatedComputeWorker.class) { 60 if (sFederatedComputeWorker == null) { 61 sFederatedComputeWorker = 62 new FederatedComputeWorker( 63 FederatedComputeJobManager.getInstance(context), 64 PhFlags.getInstance()); 65 } 66 return sFederatedComputeWorker; 67 } 68 } 69 70 /** Starts a training run with the given job Id. */ startTrainingRun(int jobId)71 public boolean startTrainingRun(int jobId) { 72 Log.d(TAG, "startTrainingRun()"); 73 FederatedTrainingTask trainingTask = mJobManager.onTrainingStarted(jobId); 74 if (trainingTask == null) { 75 Log.i(TAG, String.format("Could not find task to run for job ID %s", jobId)); 76 return false; 77 } 78 79 synchronized (LOCK) { 80 // Only allow one concurrent federated computation job. 81 if (mActiveRun != null) { 82 Log.i( 83 TAG, 84 String.format( 85 "Delaying %d/%s another run is already active!", 86 jobId, trainingTask.populationName())); 87 mJobManager.onTrainingCompleted( 88 jobId, 89 trainingTask.populationName(), 90 trainingTask.getTrainingIntervalOptions(), 91 /* taskRetry= */ null, 92 TrainingResult.FAIL); 93 return false; 94 } 95 TrainingRun run = new TrainingRun(jobId, trainingTask); 96 this.mActiveRun = run; 97 doTraining(run); 98 // TODO: get retry info from federated server. 99 TaskRetry taskRetry = SchedulingUtil.generateTransientErrorTaskRetry(mFlags); 100 finish(this.mActiveRun, taskRetry, TrainingResult.SUCCESS); 101 } 102 return true; 103 } 104 105 /** Cancels the running job if present. */ cancelActiveRun()106 public void cancelActiveRun() { 107 Log.d(TAG, "cancelActiveRun()"); 108 synchronized (LOCK) { 109 if (mActiveRun == null) { 110 return; 111 } 112 finish(mActiveRun, /* taskRetry= */ null, TrainingResult.FAIL); 113 } 114 } 115 finish( TrainingRun runToFinish, TaskRetry taskRetry, @TrainingResult int trainingResult)116 private void finish( 117 TrainingRun runToFinish, TaskRetry taskRetry, @TrainingResult int trainingResult) { 118 synchronized (LOCK) { 119 if (mActiveRun != runToFinish) { 120 return; 121 } 122 mActiveRun = null; 123 mJobManager.onTrainingCompleted( 124 runToFinish.mJobId, 125 runToFinish.mTask.populationName(), 126 runToFinish.mTask.getTrainingIntervalOptions(), 127 taskRetry, 128 trainingResult); 129 } 130 } 131 doTraining(TrainingRun run)132 private void doTraining(TrainingRun run) { 133 // TODO: add training logic. 134 Log.d(TAG, "Start run training job " + run.mJobId); 135 } 136 137 private static final class TrainingRun { 138 private final int mJobId; 139 private final FederatedTrainingTask mTask; 140 TrainingRun(int jobId, FederatedTrainingTask task)141 private TrainingRun(int jobId, FederatedTrainingTask task) { 142 this.mJobId = jobId; 143 this.mTask = task; 144 } 145 } 146 } 147