1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.federatedcompute.services.scheduling; 18 19 import static java.lang.Math.max; 20 import static java.lang.Math.min; 21 import static java.util.concurrent.TimeUnit.SECONDS; 22 23 import android.federatedcompute.common.TrainingInterval; 24 import android.federatedcompute.common.TrainingOptions; 25 26 import com.android.federatedcompute.services.common.Flags; 27 import com.android.federatedcompute.services.common.TaskRetry; 28 import com.android.federatedcompute.services.data.FederatedTrainingTask; 29 import com.android.federatedcompute.services.data.fbs.SchedulingMode; 30 import com.android.federatedcompute.services.data.fbs.TrainingIntervalOptions; 31 32 import java.util.Random; 33 34 /** The util function about federated job scheduler. */ 35 public class SchedulingUtil { SchedulingUtil()36 private SchedulingUtil() {} 37 38 /** Gets the next run time when federated compute job finishes. */ getEarliestRuntimeForFCReschedule( long nowMs, TrainingIntervalOptions interval, TaskRetry taskRetry, boolean hasContributed, Flags flags)39 public static long getEarliestRuntimeForFCReschedule( 40 long nowMs, 41 TrainingIntervalOptions interval, 42 TaskRetry taskRetry, 43 boolean hasContributed, 44 Flags flags) { 45 long newLatencyMillis; 46 if (taskRetry == null || (taskRetry.getMinDelay() <= 0 && taskRetry.getMaxDelay() <= 0)) { 47 TaskRetry transientErrorRetry = generateTransientErrorTaskRetry(flags); 48 newLatencyMillis = 49 generateMinimumDelayMillisFromRange( 50 transientErrorRetry.getMinDelay(), transientErrorRetry.getMaxDelay()); 51 } else { 52 long unsanitizedMillis = 53 generateMinimumDelayMillisFromRange( 54 taskRetry.getMinDelay(), taskRetry.getMaxDelay()); 55 long serverSpecifiedLatency = 56 sanitizeMinimumLatencyMillis( 57 unsanitizedMillis, SchedulingMode.UNDEFINED, flags); 58 if (interval.schedulingMode() == SchedulingMode.RECURRENT && hasContributed) { 59 // Only use the user-specified retry latency if we actually successfully published a 60 // result to the server. 61 long userSpecifiedLatency = interval.minIntervalMillis(); 62 userSpecifiedLatency = 63 sanitizeMinimumLatencyMillis( 64 userSpecifiedLatency, SchedulingMode.RECURRENT, flags); 65 newLatencyMillis = max(userSpecifiedLatency, serverSpecifiedLatency); 66 } else { 67 // Use server defined retry window 68 newLatencyMillis = serverSpecifiedLatency; 69 } 70 } 71 return nowMs + newLatencyMillis; 72 } 73 74 /** Gets the next run time when first time schedule the federated compute job. */ getEarliestRuntimeForInitialSchedule( long nowMs, long lastRunTimeMs, TrainingOptions trainerOptions, Flags flags)75 public static long getEarliestRuntimeForInitialSchedule( 76 long nowMs, long lastRunTimeMs, TrainingOptions trainerOptions, Flags flags) { 77 long defaultNextRunTimeMs = 78 nowMs + SECONDS.toMillis(flags.getDefaultSchedulingPeriodSecs()); 79 int schedulingMode = 80 trainerOptions.getTrainingInterval() != null 81 ? convertSchedulingMode( 82 trainerOptions.getTrainingInterval().getSchedulingMode()) 83 : SchedulingMode.ONE_TIME; 84 if (schedulingMode != SchedulingMode.RECURRENT) { 85 // Non-recurrent task doesn't have user defined interval. 86 return defaultNextRunTimeMs; 87 } 88 89 long userDefinedMinIntervalMillis = 90 sanitizeMinimumLatencyMillis( 91 trainerOptions.getTrainingInterval() == null 92 ? 0 93 : trainerOptions.getTrainingInterval().getMinimumIntervalMillis(), 94 schedulingMode, 95 flags); 96 // Take the smaller value of default next run time, and the next run time with user defined 97 // interval. 98 long minIntervalMsForRecurrentTask = 99 min(nowMs + userDefinedMinIntervalMillis, defaultNextRunTimeMs); 100 101 if (lastRunTimeMs == 0) { 102 // The task has never run in the past 103 return minIntervalMsForRecurrentTask; 104 } else { 105 // If the task has run in the past, we want to make sure the user defined minimum 106 // interval has passed since last time it ran. 107 return max(lastRunTimeMs + userDefinedMinIntervalMillis, minIntervalMsForRecurrentTask); 108 } 109 } 110 111 /** Gets the next run time when the federated job with same job id may be running. */ getEarliestRuntimeForExistingTask( FederatedTrainingTask existingTask, TrainingOptions trainingOptions, Flags flags, long nowMs)112 public static long getEarliestRuntimeForExistingTask( 113 FederatedTrainingTask existingTask, 114 TrainingOptions trainingOptions, 115 Flags flags, 116 long nowMs) { 117 long existingTaskMinLatencyMillis = existingTask.earliestNextRunTime() - nowMs; 118 int schedulingMode = 119 trainingOptions.getTrainingInterval() != null 120 ? convertSchedulingMode( 121 trainingOptions.getTrainingInterval().getSchedulingMode()) 122 : SchedulingMode.ONE_TIME; 123 long sanitizedMinLatencyMillis = 124 sanitizeMinimumLatencyMillis(existingTaskMinLatencyMillis, schedulingMode, flags); 125 return nowMs + sanitizedMinLatencyMillis; 126 } 127 128 /** Gets the task retry range for transient error happens and worth retry. */ generateTransientErrorTaskRetry(Flags flags)129 public static TaskRetry generateTransientErrorTaskRetry(Flags flags) { 130 double jitterPercent = min(1.0, max(0.0, flags.getTransientErrorRetryDelayJitterPercent())); 131 long targetDelayMillis = SECONDS.toMillis(flags.getTransientErrorRetryDelaySecs()); 132 long maxDelay = (long) (targetDelayMillis * (1.0 + jitterPercent)); 133 long minDelay = (long) (targetDelayMillis * (1.0 - jitterPercent)); 134 return new TaskRetry.Builder().setMaxDelay(maxDelay).setMinDelay(minDelay).build(); 135 } 136 137 /** Generates a random delay between the provided min and max values. */ generateMinimumDelayMillisFromRange(long minMillis, long maxMillis)138 private static long generateMinimumDelayMillisFromRange(long minMillis, long maxMillis) { 139 // Sanitize the min/max values. 140 minMillis = max(0, minMillis); 141 maxMillis = max(minMillis, maxMillis); 142 Random randomGen = new Random(); 143 return minMillis + (long) ((double) (maxMillis - minMillis) * randomGen.nextDouble()); 144 } 145 sanitizeMinimumLatencyMillis( long unsanitizedMillis, int schedulingMode, Flags flags)146 private static long sanitizeMinimumLatencyMillis( 147 long unsanitizedMillis, int schedulingMode, Flags flags) { 148 long lowerBoundMillis; 149 long upperBoundMillis; 150 if (schedulingMode == SchedulingMode.RECURRENT) { 151 // Recurrent task with user defined interval 152 lowerBoundMillis = 153 SECONDS.toMillis(flags.getMinSchedulingIntervalSecsForFederatedComputation()); 154 upperBoundMillis = 155 SECONDS.toMillis(flags.getMaxSchedulingIntervalSecsForFederatedComputation()); 156 } else { 157 // One-time task or recurrent task without user defined interval 158 lowerBoundMillis = 0L; 159 upperBoundMillis = SECONDS.toMillis(flags.getMaxSchedulingPeriodSecs()); 160 } 161 return max(lowerBoundMillis, min(upperBoundMillis, unsanitizedMillis)); 162 } 163 164 /** Converts from TrainingOptions SchedulingMode to the storage fbs.SchedulingMode. */ convertSchedulingMode(@rainingInterval.SchedulingMode int schedulingMode)165 public static int convertSchedulingMode(@TrainingInterval.SchedulingMode int schedulingMode) { 166 if (schedulingMode == TrainingInterval.SCHEDULING_MODE_RECURRENT) { 167 return SchedulingMode.RECURRENT; 168 } else if (schedulingMode == TrainingInterval.SCHEDULING_MODE_ONE_TIME) { 169 return SchedulingMode.ONE_TIME; 170 } else { 171 throw new IllegalStateException("Unknown value for scheduling mode"); 172 } 173 } 174 } 175