• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.ondevicepersonalization.services.federatedcompute;
18 
19 import android.adservices.ondevicepersonalization.Constants;
20 import android.adservices.ondevicepersonalization.IsolatedServiceException;
21 import android.adservices.ondevicepersonalization.TrainingExampleRecord;
22 import android.adservices.ondevicepersonalization.TrainingExamplesInputParcel;
23 import android.adservices.ondevicepersonalization.TrainingExamplesOutputParcel;
24 import android.adservices.ondevicepersonalization.UserData;
25 import android.annotation.NonNull;
26 import android.content.ComponentName;
27 import android.content.Context;
28 import android.federatedcompute.ExampleStoreService;
29 import android.federatedcompute.FederatedComputeManager;
30 import android.federatedcompute.common.ClientConstants;
31 import android.os.Bundle;
32 import android.os.OutcomeReceiver;
33 
34 import com.android.odp.module.common.Clock;
35 import com.android.odp.module.common.MonotonicClock;
36 import com.android.odp.module.common.PackageUtils;
37 import com.android.ondevicepersonalization.internal.util.LoggerFactory;
38 import com.android.ondevicepersonalization.internal.util.OdpParceledListSlice;
39 import com.android.ondevicepersonalization.services.Flags;
40 import com.android.ondevicepersonalization.services.FlagsFactory;
41 import com.android.ondevicepersonalization.services.OdpServiceException;
42 import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors;
43 import com.android.ondevicepersonalization.services.data.DataAccessPermission;
44 import com.android.ondevicepersonalization.services.data.DataAccessServiceImpl;
45 import com.android.ondevicepersonalization.services.data.events.EventState;
46 import com.android.ondevicepersonalization.services.data.events.EventsDao;
47 import com.android.ondevicepersonalization.services.data.user.UserPrivacyStatus;
48 import com.android.ondevicepersonalization.services.manifest.AppManifestConfigHelper;
49 import com.android.ondevicepersonalization.services.policyengine.UserDataAccessor;
50 import com.android.ondevicepersonalization.services.process.IsolatedServiceInfo;
51 import com.android.ondevicepersonalization.services.process.ProcessRunner;
52 import com.android.ondevicepersonalization.services.process.ProcessRunnerFactory;
53 import com.android.ondevicepersonalization.services.util.AllowListUtils;
54 import com.android.ondevicepersonalization.services.util.StatsUtils;
55 
56 import com.google.common.util.concurrent.FluentFuture;
57 import com.google.common.util.concurrent.FutureCallback;
58 import com.google.common.util.concurrent.Futures;
59 import com.google.common.util.concurrent.ListenableFuture;
60 import com.google.common.util.concurrent.ListeningScheduledExecutorService;
61 
62 import java.util.Objects;
63 import java.util.concurrent.CountDownLatch;
64 import java.util.concurrent.TimeUnit;
65 import java.util.concurrent.TimeoutException;
66 
67 /** Implementation of ExampleStoreService for OnDevicePersonalization */
68 public final class OdpExampleStoreService extends ExampleStoreService {
69 
70     private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger();
71     private static final String TAG = OdpExampleStoreService.class.getSimpleName();
72     private static final String TASK_NAME = "ExampleStore";
73 
74     static class Injector {
getClock()75         Clock getClock() {
76             return MonotonicClock.getInstance();
77         }
78 
getFlags()79         Flags getFlags() {
80             return FlagsFactory.getFlags();
81         }
82 
getScheduledExecutor()83         ListeningScheduledExecutorService getScheduledExecutor() {
84             return OnDevicePersonalizationExecutors.getScheduledExecutor();
85         }
86 
getProcessRunner()87         ProcessRunner getProcessRunner() {
88             return ProcessRunnerFactory.getProcessRunner();
89         }
90     }
91 
92     private final Injector mInjector = new Injector();
93 
94     /** Generates a unique task identifier from the given strings */
getTaskIdentifier(String populationName, String taskId)95     public static String getTaskIdentifier(String populationName, String taskId) {
96         return populationName + "_" + taskId;
97     }
98 
99     /** Generates a unique task identifier from the given strings */
getTaskIdentifier( String populationName, String taskId, String collectionUri)100     public static String getTaskIdentifier(
101             String populationName, String taskId, String collectionUri) {
102         return populationName + "_" + taskId + "_" + collectionUri;
103     }
104 
isCollectionUriPresent(String collectionUri)105     private static boolean isCollectionUriPresent(String collectionUri) {
106         return collectionUri != null && !collectionUri.isEmpty();
107     }
108 
109     @Override
startQuery(@onNull Bundle params, @NonNull QueryCallback callback)110     public void startQuery(@NonNull Bundle params, @NonNull QueryCallback callback) {
111         try {
112             long startTime = mInjector.getClock().currentTimeMillis();
113             ContextData contextData =
114                     ContextData.fromByteArray(
115                             Objects.requireNonNull(
116                                     params.getByteArray(ClientConstants.EXTRA_CONTEXT_DATA)));
117             String packageName = contextData.getPackageName();
118             String ownerClassName = contextData.getClassName();
119             String populationName =
120                     Objects.requireNonNull(params.getString(ClientConstants.EXTRA_POPULATION_NAME));
121             String taskId = Objects.requireNonNull(params.getString(ClientConstants.EXTRA_TASK_ID));
122             String collectionUri = params.getString(ClientConstants.EXTRA_COLLECTION_URI);
123             int eligibilityMinExample =
124                     params.getInt(ClientConstants.EXTRA_ELIGIBILITY_MIN_EXAMPLE);
125 
126             EventsDao eventDao = EventsDao.getInstance(getContext());
127 
128             boolean privacyStatusEligible = true;
129 
130             if (!UserPrivacyStatus.getInstance().isMeasurementEnabled()) {
131                 privacyStatusEligible = false;
132                 sLogger.w(TAG + ": Measurement control is not given.");
133                 StatsUtils.writeServiceRequestMetrics(
134                         Constants.API_NAME_SERVICE_ON_TRAINING_EXAMPLE,
135                         packageName,
136                         null,
137                         mInjector.getClock(),
138                         Constants.STATUS_PERSONALIZATION_DISABLED,
139                         startTime);
140             }
141 
142             // Cancel job if on longer valid. This is written to the table during scheduling
143             // via {@link FederatedComputeServiceImpl} and deleted either during cancel or
144             // during maintenance for uninstalled packages.
145             ComponentName owner = ComponentName.createRelative(packageName, ownerClassName);
146             EventState eventStatePopulation = eventDao.getEventState(populationName, owner);
147             if (eventStatePopulation == null) {
148                 StatsUtils.writeServiceRequestMetrics(
149                         Constants.API_NAME_SERVICE_ON_TRAINING_EXAMPLE,
150                         packageName,
151                         null,
152                         mInjector.getClock(),
153                         Constants.STATUS_KEY_NOT_FOUND,
154                         startTime);
155             }
156             if (!privacyStatusEligible || eventStatePopulation == null) {
157                 sLogger.w("Job was either cancelled or package was uninstalled");
158                 // Cancel job.
159                 FederatedComputeManager FCManager =
160                         getContext().getSystemService(FederatedComputeManager.class);
161                 if (FCManager == null) {
162                     sLogger.e(TAG + ": Failed to get FederatedCompute Service");
163                     callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR);
164                     return;
165                 }
166                 FCManager.cancel(
167                         owner,
168                         populationName,
169                         OnDevicePersonalizationExecutors.getBackgroundExecutor(),
170                         new OutcomeReceiver<Object, Exception>() {
171                             @Override
172                             public void onResult(Object result) {
173                                 sLogger.d(TAG + ": Successfully canceled job");
174                                 callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR);
175                             }
176 
177                             @Override
178                             public void onError(Exception error) {
179                                 sLogger.e(TAG + ": Error while cancelling job", error);
180                                 OutcomeReceiver.super.onError(error);
181                                 callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR);
182                             }
183                         });
184                 return;
185             }
186 
187             // Get resumptionToken
188             EventState eventState =
189                     eventDao.getEventState(
190                             isCollectionUriPresent(collectionUri)
191                                     ? getTaskIdentifier(populationName, taskId, collectionUri)
192                                     : getTaskIdentifier(populationName, taskId),
193                             owner);
194             byte[] resumptionToken = null;
195             if (eventState != null) {
196                 resumptionToken = eventState.getToken();
197             }
198 
199             TrainingExamplesInputParcel.Builder input =
200                     new TrainingExamplesInputParcel.Builder()
201                             .setResumptionToken(resumptionToken)
202                             .setPopulationName(populationName)
203                             .setTaskName(taskId);
204             if (isCollectionUriPresent(collectionUri)) {
205                 input.setCollectionName(collectionUri);
206             }
207 
208             String className =
209                     AppManifestConfigHelper.getServiceNameFromOdpSettings(
210                             getContext(), packageName);
211             ListenableFuture<IsolatedServiceInfo> loadFuture =
212                     mInjector
213                             .getProcessRunner()
214                             .loadIsolatedService(
215                                     TASK_NAME,
216                                     ComponentName.createRelative(packageName, className));
217             ListenableFuture<Bundle> resultFuture =
218                     FluentFuture.from(loadFuture)
219                             .transformAsync(
220                                     result ->
221                                             executeOnTrainingExamples(
222                                                     result, input.build(), packageName),
223                                     OnDevicePersonalizationExecutors.getBackgroundExecutor())
224                             .withTimeout(
225                                     mInjector.getFlags().getIsolatedServiceDeadlineSeconds(),
226                                     TimeUnit.SECONDS,
227                                     mInjector.getScheduledExecutor());
228 
229             CountDownLatch latch = new CountDownLatch(1);
230             Futures.addCallback(
231                     resultFuture,
232                     new FutureCallback<Bundle>() {
233                         @Override
234                         public void onSuccess(Bundle result) {
235                             int status = Constants.STATUS_SUCCESS;
236                             try {
237                                 TrainingExamplesOutputParcel trainingExamplesOutputParcel =
238                                         result.getParcelable(
239                                                 Constants.EXTRA_RESULT,
240                                                 TrainingExamplesOutputParcel.class);
241                                 if (trainingExamplesOutputParcel == null) {
242                                     status = Constants.STATUS_NAME_NOT_FOUND;
243                                     callback.onStartQueryFailure(
244                                             ClientConstants.STATUS_INTERNAL_ERROR);
245                                     return;
246                                 }
247                                 OdpParceledListSlice<TrainingExampleRecord>
248                                         trainingExampleRecordList =
249                                                 trainingExamplesOutputParcel
250                                                         .getTrainingExampleRecords();
251                                 if (trainingExampleRecordList == null
252                                         || trainingExampleRecordList.getList().isEmpty()) {
253                                     status = Constants.STATUS_SUCCESS_EMPTY_RESULT;
254                                     callback.onStartQueryFailure(
255                                             ClientConstants.STATUS_NOT_ENOUGH_DATA);
256                                 } else if (trainingExampleRecordList.getList().size()
257                                         < eligibilityMinExample) {
258                                     sLogger.d(TAG + ": not enough examples, requires %d got %d",
259                                             eligibilityMinExample,
260                                             trainingExampleRecordList.getList().size());
261                                     status = Constants.STATUS_SUCCESS_NOT_ENOUGH_DATA;
262                                     callback.onStartQueryFailure(
263                                             ClientConstants.STATUS_NOT_ENOUGH_DATA);
264                                 } else {
265                                     callback.onStartQuerySuccess(
266                                             OdpExampleStoreIteratorFactory.getInstance()
267                                                     .createIterator(
268                                                             trainingExampleRecordList.getList()));
269                                 }
270                             } finally {
271                                 latch.countDown();
272                                 StatsUtils.writeServiceRequestMetrics(
273                                         Constants.API_NAME_SERVICE_ON_TRAINING_EXAMPLE,
274                                         packageName,
275                                         result,
276                                         mInjector.getClock(),
277                                         status,
278                                         startTime);
279                             }
280                         }
281 
282                         @Override
283                         public void onFailure(Throwable t) {
284                             latch.countDown();
285                             int status = Constants.STATUS_INTERNAL_ERROR;
286                             if (t instanceof TimeoutException) {
287                                 status = Constants.STATUS_TIMEOUT;
288                             } else if (t instanceof OdpServiceException exp) {
289                                 if (exp.getCause() instanceof IsolatedServiceException
290                                         && isLogIsolatedServiceErrorCodeNonAggregatedAllowed(
291                                                 packageName)) {
292                                     status = ((IsolatedServiceException) exp.getCause())
293                                             .getErrorCode();
294                                 } else {
295                                     status = exp.getErrorCode();
296                                 }
297                             }
298                             sLogger.w(t, "%s : Request failed.", TAG);
299                             StatsUtils.writeServiceRequestMetrics(
300                                     Constants.API_NAME_SERVICE_ON_TRAINING_EXAMPLE,
301                                     packageName,
302                                     null,
303                                     mInjector.getClock(),
304                                     status,
305                                     startTime);
306                             callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR);
307                         }
308                     },
309                     OnDevicePersonalizationExecutors.getBackgroundExecutor());
310 
311             var unused =
312                     Futures.whenAllComplete(loadFuture, resultFuture)
313                             .callAsync(
314                                     () -> {
315                                         try {
316                                             latch.await();
317                                         } catch (InterruptedException e) {
318                                             sLogger.e(e, "%s : Interrupted while "
319                                                     + "waiting for transaction complete", TAG);
320                                         }
321                                         return mInjector
322                                                 .getProcessRunner()
323                                                 .unloadIsolatedService(loadFuture.get());
324                                     },
325                                     OnDevicePersonalizationExecutors.getBackgroundExecutor());
326         } catch (Throwable e) {
327             sLogger.e(e, "%s : Start query failed.", TAG);
328             StatsUtils.writeServiceRequestMetrics(
329                     Constants.API_NAME_SERVICE_ON_TRAINING_EXAMPLE,
330                     Constants.STATUS_INTERNAL_ERROR);
331             callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR);
332         }
333     }
334 
executeOnTrainingExamples( IsolatedServiceInfo isolatedServiceInfo, TrainingExamplesInputParcel exampleInput, String packageName)335     private ListenableFuture<Bundle> executeOnTrainingExamples(
336             IsolatedServiceInfo isolatedServiceInfo,
337             TrainingExamplesInputParcel exampleInput,
338             String packageName) {
339         sLogger.d(TAG + ": executeOnTrainingExamples() started.");
340         Bundle serviceParams = new Bundle();
341         serviceParams.putParcelable(Constants.EXTRA_INPUT, exampleInput);
342         String serviceClass =
343                 AppManifestConfigHelper.getServiceNameFromOdpSettings(getContext(), packageName);
344         DataAccessServiceImpl binder =
345                 new DataAccessServiceImpl(
346                         ComponentName.createRelative(packageName, serviceClass),
347                         getContext(),
348                         // ODP provides accurate user signal in training flow, so we disable write
349                         // access of databases to prevent leak.
350                         /* localDataPermission */ DataAccessPermission.READ_ONLY,
351                         /* eventDataPermission */ DataAccessPermission.READ_ONLY);
352         serviceParams.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, binder);
353         UserDataAccessor userDataAccessor = new UserDataAccessor();
354         UserData userData;
355         // By default, we don't provide platform data for federated learning flow.
356         if (isPlatformDataProvided(packageName)) {
357             userData = userDataAccessor.getUserDataWithAppInstall();
358         } else {
359             userData = userDataAccessor.getUserData();
360         }
361         serviceParams.putParcelable(Constants.EXTRA_USER_DATA, userData);
362         return mInjector
363                 .getProcessRunner()
364                 .runIsolatedService(
365                         isolatedServiceInfo, Constants.OP_TRAINING_EXAMPLE, serviceParams);
366     }
367 
368     // used for tests to provide mock/real implementation of context.
getContext()369     private Context getContext() {
370         return this.getApplicationContext();
371     }
372 
isPlatformDataProvided(String packageName)373     private boolean isPlatformDataProvided(String packageName) {
374         try {
375             return AllowListUtils.isAllowListed(
376                     packageName,
377                     PackageUtils.getCertDigest(getContext(), packageName),
378                     mInjector.getFlags().getDefaultPlatformDataForExecuteAllowlist());
379         } catch (Exception e) {
380             sLogger.d(TAG + ": allow list error", e);
381             return false;
382         }
383     }
384 
isLogIsolatedServiceErrorCodeNonAggregatedAllowed(String packageName)385     private boolean isLogIsolatedServiceErrorCodeNonAggregatedAllowed(String packageName) {
386         try {
387             return AllowListUtils.isAllowListed(
388                     packageName,
389                     null,
390                     mInjector.getFlags().getLogIsolatedServiceErrorCodeNonAggregatedAllowlist());
391         } catch (Exception e) {
392             sLogger.d(e, TAG + ": check isLogIsolatedServiceErrorCodeNonAggregatedAllowed error");
393             return false;
394         }
395     }
396 }
397