1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.ondevicepersonalization.services.federatedcompute; 18 19 import android.adservices.ondevicepersonalization.Constants; 20 import android.adservices.ondevicepersonalization.IsolatedServiceException; 21 import android.adservices.ondevicepersonalization.TrainingExampleRecord; 22 import android.adservices.ondevicepersonalization.TrainingExamplesInputParcel; 23 import android.adservices.ondevicepersonalization.TrainingExamplesOutputParcel; 24 import android.adservices.ondevicepersonalization.UserData; 25 import android.annotation.NonNull; 26 import android.content.ComponentName; 27 import android.content.Context; 28 import android.federatedcompute.ExampleStoreService; 29 import android.federatedcompute.FederatedComputeManager; 30 import android.federatedcompute.common.ClientConstants; 31 import android.os.Bundle; 32 import android.os.OutcomeReceiver; 33 34 import com.android.odp.module.common.Clock; 35 import com.android.odp.module.common.MonotonicClock; 36 import com.android.odp.module.common.PackageUtils; 37 import com.android.ondevicepersonalization.internal.util.LoggerFactory; 38 import com.android.ondevicepersonalization.internal.util.OdpParceledListSlice; 39 import com.android.ondevicepersonalization.services.Flags; 40 import com.android.ondevicepersonalization.services.FlagsFactory; 41 import com.android.ondevicepersonalization.services.OdpServiceException; 42 import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors; 43 import com.android.ondevicepersonalization.services.data.DataAccessPermission; 44 import com.android.ondevicepersonalization.services.data.DataAccessServiceImpl; 45 import com.android.ondevicepersonalization.services.data.events.EventState; 46 import com.android.ondevicepersonalization.services.data.events.EventsDao; 47 import com.android.ondevicepersonalization.services.data.user.UserPrivacyStatus; 48 import com.android.ondevicepersonalization.services.manifest.AppManifestConfigHelper; 49 import com.android.ondevicepersonalization.services.policyengine.UserDataAccessor; 50 import com.android.ondevicepersonalization.services.process.IsolatedServiceInfo; 51 import com.android.ondevicepersonalization.services.process.ProcessRunner; 52 import com.android.ondevicepersonalization.services.process.ProcessRunnerFactory; 53 import com.android.ondevicepersonalization.services.util.AllowListUtils; 54 import com.android.ondevicepersonalization.services.util.StatsUtils; 55 56 import com.google.common.util.concurrent.FluentFuture; 57 import com.google.common.util.concurrent.FutureCallback; 58 import com.google.common.util.concurrent.Futures; 59 import com.google.common.util.concurrent.ListenableFuture; 60 import com.google.common.util.concurrent.ListeningScheduledExecutorService; 61 62 import java.util.Objects; 63 import java.util.concurrent.CountDownLatch; 64 import java.util.concurrent.TimeUnit; 65 import java.util.concurrent.TimeoutException; 66 67 /** Implementation of ExampleStoreService for OnDevicePersonalization */ 68 public final class OdpExampleStoreService extends ExampleStoreService { 69 70 private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger(); 71 private static final String TAG = OdpExampleStoreService.class.getSimpleName(); 72 private static final String TASK_NAME = "ExampleStore"; 73 74 static class Injector { getClock()75 Clock getClock() { 76 return MonotonicClock.getInstance(); 77 } 78 getFlags()79 Flags getFlags() { 80 return FlagsFactory.getFlags(); 81 } 82 getScheduledExecutor()83 ListeningScheduledExecutorService getScheduledExecutor() { 84 return OnDevicePersonalizationExecutors.getScheduledExecutor(); 85 } 86 getProcessRunner()87 ProcessRunner getProcessRunner() { 88 return ProcessRunnerFactory.getProcessRunner(); 89 } 90 } 91 92 private final Injector mInjector = new Injector(); 93 94 /** Generates a unique task identifier from the given strings */ getTaskIdentifier(String populationName, String taskId)95 public static String getTaskIdentifier(String populationName, String taskId) { 96 return populationName + "_" + taskId; 97 } 98 99 /** Generates a unique task identifier from the given strings */ getTaskIdentifier( String populationName, String taskId, String collectionUri)100 public static String getTaskIdentifier( 101 String populationName, String taskId, String collectionUri) { 102 return populationName + "_" + taskId + "_" + collectionUri; 103 } 104 isCollectionUriPresent(String collectionUri)105 private static boolean isCollectionUriPresent(String collectionUri) { 106 return collectionUri != null && !collectionUri.isEmpty(); 107 } 108 109 @Override startQuery(@onNull Bundle params, @NonNull QueryCallback callback)110 public void startQuery(@NonNull Bundle params, @NonNull QueryCallback callback) { 111 try { 112 long startTime = mInjector.getClock().currentTimeMillis(); 113 ContextData contextData = 114 ContextData.fromByteArray( 115 Objects.requireNonNull( 116 params.getByteArray(ClientConstants.EXTRA_CONTEXT_DATA))); 117 String packageName = contextData.getPackageName(); 118 String ownerClassName = contextData.getClassName(); 119 String populationName = 120 Objects.requireNonNull(params.getString(ClientConstants.EXTRA_POPULATION_NAME)); 121 String taskId = Objects.requireNonNull(params.getString(ClientConstants.EXTRA_TASK_ID)); 122 String collectionUri = params.getString(ClientConstants.EXTRA_COLLECTION_URI); 123 int eligibilityMinExample = 124 params.getInt(ClientConstants.EXTRA_ELIGIBILITY_MIN_EXAMPLE); 125 126 EventsDao eventDao = EventsDao.getInstance(getContext()); 127 128 boolean privacyStatusEligible = true; 129 130 if (!UserPrivacyStatus.getInstance().isMeasurementEnabled()) { 131 privacyStatusEligible = false; 132 sLogger.w(TAG + ": Measurement control is not given."); 133 StatsUtils.writeServiceRequestMetrics( 134 Constants.API_NAME_SERVICE_ON_TRAINING_EXAMPLE, 135 packageName, 136 null, 137 mInjector.getClock(), 138 Constants.STATUS_PERSONALIZATION_DISABLED, 139 startTime); 140 } 141 142 // Cancel job if on longer valid. This is written to the table during scheduling 143 // via {@link FederatedComputeServiceImpl} and deleted either during cancel or 144 // during maintenance for uninstalled packages. 145 ComponentName owner = ComponentName.createRelative(packageName, ownerClassName); 146 EventState eventStatePopulation = eventDao.getEventState(populationName, owner); 147 if (eventStatePopulation == null) { 148 StatsUtils.writeServiceRequestMetrics( 149 Constants.API_NAME_SERVICE_ON_TRAINING_EXAMPLE, 150 packageName, 151 null, 152 mInjector.getClock(), 153 Constants.STATUS_KEY_NOT_FOUND, 154 startTime); 155 } 156 if (!privacyStatusEligible || eventStatePopulation == null) { 157 sLogger.w("Job was either cancelled or package was uninstalled"); 158 // Cancel job. 159 FederatedComputeManager FCManager = 160 getContext().getSystemService(FederatedComputeManager.class); 161 if (FCManager == null) { 162 sLogger.e(TAG + ": Failed to get FederatedCompute Service"); 163 callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR); 164 return; 165 } 166 FCManager.cancel( 167 owner, 168 populationName, 169 OnDevicePersonalizationExecutors.getBackgroundExecutor(), 170 new OutcomeReceiver<Object, Exception>() { 171 @Override 172 public void onResult(Object result) { 173 sLogger.d(TAG + ": Successfully canceled job"); 174 callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR); 175 } 176 177 @Override 178 public void onError(Exception error) { 179 sLogger.e(TAG + ": Error while cancelling job", error); 180 OutcomeReceiver.super.onError(error); 181 callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR); 182 } 183 }); 184 return; 185 } 186 187 // Get resumptionToken 188 EventState eventState = 189 eventDao.getEventState( 190 isCollectionUriPresent(collectionUri) 191 ? getTaskIdentifier(populationName, taskId, collectionUri) 192 : getTaskIdentifier(populationName, taskId), 193 owner); 194 byte[] resumptionToken = null; 195 if (eventState != null) { 196 resumptionToken = eventState.getToken(); 197 } 198 199 TrainingExamplesInputParcel.Builder input = 200 new TrainingExamplesInputParcel.Builder() 201 .setResumptionToken(resumptionToken) 202 .setPopulationName(populationName) 203 .setTaskName(taskId); 204 if (isCollectionUriPresent(collectionUri)) { 205 input.setCollectionName(collectionUri); 206 } 207 208 String className = 209 AppManifestConfigHelper.getServiceNameFromOdpSettings( 210 getContext(), packageName); 211 ListenableFuture<IsolatedServiceInfo> loadFuture = 212 mInjector 213 .getProcessRunner() 214 .loadIsolatedService( 215 TASK_NAME, 216 ComponentName.createRelative(packageName, className)); 217 ListenableFuture<Bundle> resultFuture = 218 FluentFuture.from(loadFuture) 219 .transformAsync( 220 result -> 221 executeOnTrainingExamples( 222 result, input.build(), packageName), 223 OnDevicePersonalizationExecutors.getBackgroundExecutor()) 224 .withTimeout( 225 mInjector.getFlags().getIsolatedServiceDeadlineSeconds(), 226 TimeUnit.SECONDS, 227 mInjector.getScheduledExecutor()); 228 229 CountDownLatch latch = new CountDownLatch(1); 230 Futures.addCallback( 231 resultFuture, 232 new FutureCallback<Bundle>() { 233 @Override 234 public void onSuccess(Bundle result) { 235 int status = Constants.STATUS_SUCCESS; 236 try { 237 TrainingExamplesOutputParcel trainingExamplesOutputParcel = 238 result.getParcelable( 239 Constants.EXTRA_RESULT, 240 TrainingExamplesOutputParcel.class); 241 if (trainingExamplesOutputParcel == null) { 242 status = Constants.STATUS_NAME_NOT_FOUND; 243 callback.onStartQueryFailure( 244 ClientConstants.STATUS_INTERNAL_ERROR); 245 return; 246 } 247 OdpParceledListSlice<TrainingExampleRecord> 248 trainingExampleRecordList = 249 trainingExamplesOutputParcel 250 .getTrainingExampleRecords(); 251 if (trainingExampleRecordList == null 252 || trainingExampleRecordList.getList().isEmpty()) { 253 status = Constants.STATUS_SUCCESS_EMPTY_RESULT; 254 callback.onStartQueryFailure( 255 ClientConstants.STATUS_NOT_ENOUGH_DATA); 256 } else if (trainingExampleRecordList.getList().size() 257 < eligibilityMinExample) { 258 sLogger.d(TAG + ": not enough examples, requires %d got %d", 259 eligibilityMinExample, 260 trainingExampleRecordList.getList().size()); 261 status = Constants.STATUS_SUCCESS_NOT_ENOUGH_DATA; 262 callback.onStartQueryFailure( 263 ClientConstants.STATUS_NOT_ENOUGH_DATA); 264 } else { 265 callback.onStartQuerySuccess( 266 OdpExampleStoreIteratorFactory.getInstance() 267 .createIterator( 268 trainingExampleRecordList.getList())); 269 } 270 } finally { 271 latch.countDown(); 272 StatsUtils.writeServiceRequestMetrics( 273 Constants.API_NAME_SERVICE_ON_TRAINING_EXAMPLE, 274 packageName, 275 result, 276 mInjector.getClock(), 277 status, 278 startTime); 279 } 280 } 281 282 @Override 283 public void onFailure(Throwable t) { 284 latch.countDown(); 285 int status = Constants.STATUS_INTERNAL_ERROR; 286 if (t instanceof TimeoutException) { 287 status = Constants.STATUS_TIMEOUT; 288 } else if (t instanceof OdpServiceException exp) { 289 if (exp.getCause() instanceof IsolatedServiceException 290 && isLogIsolatedServiceErrorCodeNonAggregatedAllowed( 291 packageName)) { 292 status = ((IsolatedServiceException) exp.getCause()) 293 .getErrorCode(); 294 } else { 295 status = exp.getErrorCode(); 296 } 297 } 298 sLogger.w(t, "%s : Request failed.", TAG); 299 StatsUtils.writeServiceRequestMetrics( 300 Constants.API_NAME_SERVICE_ON_TRAINING_EXAMPLE, 301 packageName, 302 null, 303 mInjector.getClock(), 304 status, 305 startTime); 306 callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR); 307 } 308 }, 309 OnDevicePersonalizationExecutors.getBackgroundExecutor()); 310 311 var unused = 312 Futures.whenAllComplete(loadFuture, resultFuture) 313 .callAsync( 314 () -> { 315 try { 316 latch.await(); 317 } catch (InterruptedException e) { 318 sLogger.e(e, "%s : Interrupted while " 319 + "waiting for transaction complete", TAG); 320 } 321 return mInjector 322 .getProcessRunner() 323 .unloadIsolatedService(loadFuture.get()); 324 }, 325 OnDevicePersonalizationExecutors.getBackgroundExecutor()); 326 } catch (Throwable e) { 327 sLogger.e(e, "%s : Start query failed.", TAG); 328 StatsUtils.writeServiceRequestMetrics( 329 Constants.API_NAME_SERVICE_ON_TRAINING_EXAMPLE, 330 Constants.STATUS_INTERNAL_ERROR); 331 callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR); 332 } 333 } 334 executeOnTrainingExamples( IsolatedServiceInfo isolatedServiceInfo, TrainingExamplesInputParcel exampleInput, String packageName)335 private ListenableFuture<Bundle> executeOnTrainingExamples( 336 IsolatedServiceInfo isolatedServiceInfo, 337 TrainingExamplesInputParcel exampleInput, 338 String packageName) { 339 sLogger.d(TAG + ": executeOnTrainingExamples() started."); 340 Bundle serviceParams = new Bundle(); 341 serviceParams.putParcelable(Constants.EXTRA_INPUT, exampleInput); 342 String serviceClass = 343 AppManifestConfigHelper.getServiceNameFromOdpSettings(getContext(), packageName); 344 DataAccessServiceImpl binder = 345 new DataAccessServiceImpl( 346 ComponentName.createRelative(packageName, serviceClass), 347 getContext(), 348 // ODP provides accurate user signal in training flow, so we disable write 349 // access of databases to prevent leak. 350 /* localDataPermission */ DataAccessPermission.READ_ONLY, 351 /* eventDataPermission */ DataAccessPermission.READ_ONLY); 352 serviceParams.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, binder); 353 UserDataAccessor userDataAccessor = new UserDataAccessor(); 354 UserData userData; 355 // By default, we don't provide platform data for federated learning flow. 356 if (isPlatformDataProvided(packageName)) { 357 userData = userDataAccessor.getUserDataWithAppInstall(); 358 } else { 359 userData = userDataAccessor.getUserData(); 360 } 361 serviceParams.putParcelable(Constants.EXTRA_USER_DATA, userData); 362 return mInjector 363 .getProcessRunner() 364 .runIsolatedService( 365 isolatedServiceInfo, Constants.OP_TRAINING_EXAMPLE, serviceParams); 366 } 367 368 // used for tests to provide mock/real implementation of context. getContext()369 private Context getContext() { 370 return this.getApplicationContext(); 371 } 372 isPlatformDataProvided(String packageName)373 private boolean isPlatformDataProvided(String packageName) { 374 try { 375 return AllowListUtils.isAllowListed( 376 packageName, 377 PackageUtils.getCertDigest(getContext(), packageName), 378 mInjector.getFlags().getDefaultPlatformDataForExecuteAllowlist()); 379 } catch (Exception e) { 380 sLogger.d(TAG + ": allow list error", e); 381 return false; 382 } 383 } 384 isLogIsolatedServiceErrorCodeNonAggregatedAllowed(String packageName)385 private boolean isLogIsolatedServiceErrorCodeNonAggregatedAllowed(String packageName) { 386 try { 387 return AllowListUtils.isAllowListed( 388 packageName, 389 null, 390 mInjector.getFlags().getLogIsolatedServiceErrorCodeNonAggregatedAllowlist()); 391 } catch (Exception e) { 392 sLogger.d(e, TAG + ": check isLogIsolatedServiceErrorCodeNonAggregatedAllowed error"); 393 return false; 394 } 395 } 396 } 397