1 /* 2 * Copyright 2022 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.adservices.ondevicepersonalization; 18 19 import static com.google.common.truth.Truth.assertThat; 20 21 import static org.junit.Assert.assertEquals; 22 import static org.junit.Assert.assertNotNull; 23 import static org.junit.Assert.assertNull; 24 import static org.junit.Assert.assertThrows; 25 import static org.junit.Assert.assertTrue; 26 27 import android.adservices.ondevicepersonalization.aidl.IDataAccessService; 28 import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback; 29 import android.adservices.ondevicepersonalization.aidl.IFederatedComputeCallback; 30 import android.adservices.ondevicepersonalization.aidl.IFederatedComputeService; 31 import android.federatedcompute.common.TrainingOptions; 32 import android.os.Bundle; 33 import android.os.RemoteException; 34 35 import androidx.test.ext.junit.runners.AndroidJUnit4; 36 import androidx.test.filters.SmallTest; 37 38 import com.android.ondevicepersonalization.testing.utils.ResultReceiver; 39 40 import com.google.common.util.concurrent.MoreExecutors; 41 42 import org.junit.Test; 43 import org.junit.runner.RunWith; 44 45 import java.time.Duration; 46 47 /** Unit Tests for {@link FederatedComputeScheduler}. 
*
 * <p>The scheduler under test is wired to two in-process fakes: {@link FederatedComputeService},
 * which records schedule/cancel calls and fails them based on the requested population name, and
 * {@link TestDataService}, which records the response code passed to {@code logApiCallStats}.
 */
@SmallTest
@RunWith(AndroidJUnit4.class)
public class FederatedComputeSchedulerTest {

    private static final String VALID_POPULATION_NAME = "population";
    private static final String ERROR_POPULATION_NAME = "err";

    private static final String INVALID_MANIFEST_ERROR_POPULATION_NAME = "manifest_error";
    private static final String POPULATION_NAME_PRIVACY_NOT_ELIGIBLE = "privacy_not_eligible";

    private static final TrainingInterval TEST_TRAINING_INTERVAL =
            new TrainingInterval.Builder()
                    .setMinimumInterval(Duration.ofHours(10))
                    .setSchedulingMode(TrainingInterval.SCHEDULING_MODE_ONE_TIME)
                    .build();

    private static final FederatedComputeScheduler.Params TEST_SCHEDULER_PARAMS =
            new FederatedComputeScheduler.Params(TEST_TRAINING_INTERVAL);

    private static final FederatedComputeInput TEST_FC_INPUT =
            new FederatedComputeInput.Builder().setPopulationName(VALID_POPULATION_NAME).build();
    private static final FederatedComputeScheduleRequest TEST_SCHEDULE_INPUT =
            new FederatedComputeScheduleRequest(TEST_SCHEDULER_PARAMS, VALID_POPULATION_NAME);

    private final FederatedComputeScheduler mFederatedComputeScheduler =
            new FederatedComputeScheduler(
                    IFederatedComputeService.Stub.asInterface(new FederatedComputeService()),
                    IDataAccessService.Stub.asInterface(new TestDataService()));

    // Set by the fakes so each test can verify which service calls were made.
    private boolean mCancelCalled = false;
    private boolean mScheduleCalled = false;
    private boolean mLogApiCalled = false;
    // Last response code passed to TestDataService#logApiCallStats. Starts as SUCCESS, so a test
    // that expects an error code also proves that a log call happened.
    private int mResponseCode = Constants.STATUS_SUCCESS;

    @Test
    public void testScheduleSuccess() {
        mFederatedComputeScheduler.schedule(TEST_SCHEDULER_PARAMS, TEST_FC_INPUT);

        assertThat(mScheduleCalled).isTrue();
        assertThat(mLogApiCalled).isTrue();
        assertThat(mResponseCode).isEqualTo(Constants.STATUS_SUCCESS);
    }

    @Test
    public void testSchedule_withOutcomeReceiver_success() throws Exception {
        var receiver = new ResultReceiver();

        mFederatedComputeScheduler.schedule(
                TEST_SCHEDULE_INPUT, MoreExecutors.directExecutor(), receiver);

        assertNotNull(receiver.getResult());
        assertTrue(receiver.isSuccess());
        assertThat(mScheduleCalled).isTrue();
        assertThat(mLogApiCalled).isTrue();
        assertThat(mResponseCode).isEqualTo(Constants.STATUS_SUCCESS);
    }

    @Test
    public void testSchedule_withOutcomeReceiver_error() throws Exception {
        FederatedComputeScheduleRequest scheduleInput =
                new FederatedComputeScheduleRequest(TEST_SCHEDULER_PARAMS, ERROR_POPULATION_NAME);
        var receiver = new ResultReceiver();

        mFederatedComputeScheduler.schedule(
                scheduleInput, MoreExecutors.directExecutor(), receiver);

        assertNull(receiver.getResult());
        assertTrue(receiver.isError());
        // Truth reports the actual exception type on failure, unlike assertTrue(x instanceof T).
        assertThat(receiver.getException()).isInstanceOf(OnDevicePersonalizationException.class);
        assertEquals(
                OnDevicePersonalizationException.ERROR_SCHEDULE_TRAINING_FAILED,
                ((OnDevicePersonalizationException) receiver.getException()).getErrorCode());
        assertThat(mScheduleCalled).isTrue();
        assertThat(mLogApiCalled).isTrue();
        assertThat(mResponseCode).isEqualTo(Constants.STATUS_INTERNAL_ERROR);
    }

    @Test
    public void testSchedule_withOutcomeReceiver_manifestError() throws Exception {
        FederatedComputeScheduleRequest scheduleInput =
                new FederatedComputeScheduleRequest(
                        TEST_SCHEDULER_PARAMS, INVALID_MANIFEST_ERROR_POPULATION_NAME);
        var receiver = new ResultReceiver();

        mFederatedComputeScheduler.schedule(
                scheduleInput, MoreExecutors.directExecutor(), receiver);

        assertNull(receiver.getResult());
        assertTrue(receiver.isError());
        assertThat(receiver.getException()).isInstanceOf(OnDevicePersonalizationException.class);
        assertEquals(
                OnDevicePersonalizationException.ERROR_INVALID_TRAINING_MANIFEST,
                ((OnDevicePersonalizationException) receiver.getException()).getErrorCode());
        assertThat(mScheduleCalled).isTrue();
        assertThat(mLogApiCalled).isTrue();
        assertThat(mResponseCode).isEqualTo(Constants.STATUS_FCP_MANIFEST_INVALID);
    }

    @Test
    public void testScheduleNull() {
        FederatedComputeScheduler fcs = createSchedulerWithoutFcService();

        assertThrows(
                IllegalStateException.class,
                () -> fcs.schedule(TEST_SCHEDULER_PARAMS, TEST_FC_INPUT));
        // The rejected call is still expected to be logged, with an internal-error status.
        assertThat(mLogApiCalled).isTrue();
        assertThat(mResponseCode).isEqualTo(Constants.STATUS_INTERNAL_ERROR);
    }

    @Test
    public void testScheduleError() {
        FederatedComputeInput input =
                new FederatedComputeInput.Builder()
                        .setPopulationName(ERROR_POPULATION_NAME)
                        .build();

        mFederatedComputeScheduler.schedule(TEST_SCHEDULER_PARAMS, input);

        assertThat(mScheduleCalled).isTrue();
        assertThat(mLogApiCalled).isTrue();
        assertThat(mResponseCode).isEqualTo(Constants.STATUS_INTERNAL_ERROR);
    }

    @Test
    public void testSchedulePrivacyNotEligible() {
        FederatedComputeInput input =
                new FederatedComputeInput.Builder()
                        .setPopulationName(POPULATION_NAME_PRIVACY_NOT_ELIGIBLE)
                        .build();

        mFederatedComputeScheduler.schedule(TEST_SCHEDULER_PARAMS, input);

        assertThat(mScheduleCalled).isTrue();
        assertThat(mLogApiCalled).isTrue();
        assertThat(mResponseCode).isEqualTo(Constants.STATUS_PERSONALIZATION_DISABLED);
    }

    @Test
    public void testCancelSuccess() {
        mFederatedComputeScheduler.cancel(TEST_FC_INPUT);

        assertThat(mCancelCalled).isTrue();
        assertThat(mLogApiCalled).isTrue();
        assertThat(mResponseCode).isEqualTo(Constants.STATUS_SUCCESS);
    }

    @Test
    public void testCancelNull() {
        FederatedComputeScheduler fcs = createSchedulerWithoutFcService();

        assertThrows(IllegalStateException.class, () -> fcs.cancel(TEST_FC_INPUT));
        // The rejected call is still expected to be logged, with an internal-error status.
        assertThat(mLogApiCalled).isTrue();
        assertThat(mResponseCode).isEqualTo(Constants.STATUS_INTERNAL_ERROR);
    }

    @Test
    public void testCancelError() {
        FederatedComputeInput input =
                new FederatedComputeInput.Builder()
                        .setPopulationName(ERROR_POPULATION_NAME)
                        .build();

        mFederatedComputeScheduler.cancel(input);

        assertThat(mCancelCalled).isTrue();
        assertThat(mLogApiCalled).isTrue();
        assertThat(mResponseCode).isEqualTo(Constants.STATUS_INTERNAL_ERROR);
    }

    /**
     * Creates a scheduler that has no federated compute service but still reports API-call stats
     * to a {@link TestDataService}. The data service is wrapped exactly like the main fixture's.
     */
    private FederatedComputeScheduler createSchedulerWithoutFcService() {
        return new FederatedComputeScheduler(
                null, IDataAccessService.Stub.asInterface(new TestDataService()));
    }

    /**
     * Fake federated compute service that records calls and answers through the callback.
     *
     * <p>{@code schedule} fails with {@code STATUS_INTERNAL_ERROR},
     * {@code STATUS_PERSONALIZATION_DISABLED} or {@code STATUS_FCP_MANIFEST_INVALID}, depending on
     * the population name. {@code cancel} fails only for {@link #ERROR_POPULATION_NAME}. Every
     * other request succeeds.
     */
    private class FederatedComputeService extends IFederatedComputeService.Stub {
        @Override
        public void schedule(
                TrainingOptions trainingOptions,
                IFederatedComputeCallback iFederatedComputeCallback)
                throws RemoteException {
            mScheduleCalled = true;
            String populationName = trainingOptions.getPopulationName();
            if (populationName.equals(ERROR_POPULATION_NAME)) {
                iFederatedComputeCallback.onFailure(Constants.STATUS_INTERNAL_ERROR);
                return;
            }
            if (populationName.equals(POPULATION_NAME_PRIVACY_NOT_ELIGIBLE)) {
                iFederatedComputeCallback.onFailure(Constants.STATUS_PERSONALIZATION_DISABLED);
                return;
            }
            if (populationName.equals(INVALID_MANIFEST_ERROR_POPULATION_NAME)) {
                iFederatedComputeCallback.onFailure(Constants.STATUS_FCP_MANIFEST_INVALID);
                return;
            }
            iFederatedComputeCallback.onSuccess();
        }

        @Override
        public void cancel(
                String populationName, IFederatedComputeCallback iFederatedComputeCallback)
                throws RemoteException {
            mCancelCalled = true;
            if (populationName.equals(ERROR_POPULATION_NAME)) {
                // Use the same named status as the schedule() fake rather than a bare magic number.
                iFederatedComputeCallback.onFailure(Constants.STATUS_INTERNAL_ERROR);
                return;
            }
            iFederatedComputeCallback.onSuccess();
        }
    }

    /**
     * Fake data access service. {@code onRequest} does nothing and never invokes its callback.
     * {@code logApiCallStats} records that it was called and which response code it received.
     */
    private class TestDataService extends IDataAccessService.Stub {

        @Override
        public void onRequest(int operation, Bundle params, IDataAccessServiceCallback callback) {}

        @Override
        public void logApiCallStats(int apiName, long latencyMillis, int responseCode) {
            mLogApiCalled = true;
            mResponseCode = responseCode;
        }
    }
}