• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.federatedcompute.services.training;
18 
19 import static android.federatedcompute.common.ClientConstants.RESULT_HANDLING_SERVICE_ACTION;
20 import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS;
21 import static android.federatedcompute.common.ClientConstants.STATUS_TRAINING_FAILED;
22 
23 import android.content.Context;
24 import android.federatedcompute.aidl.IFederatedComputeCallback;
25 import android.federatedcompute.aidl.IResultHandlingService;
26 import android.federatedcompute.common.ClientConstants;
27 import android.os.Bundle;
28 
29 import com.android.federatedcompute.internal.util.AbstractServiceBinder;
30 import com.android.federatedcompute.internal.util.LogUtil;
31 import com.android.federatedcompute.services.data.FederatedTrainingTask;
32 import com.android.federatedcompute.services.training.util.ComputationResult;
33 
34 import com.google.common.annotations.VisibleForTesting;
35 import com.google.common.util.concurrent.Futures;
36 import com.google.common.util.concurrent.ListenableFuture;
37 
38 import java.util.concurrent.ArrayBlockingQueue;
39 import java.util.concurrent.BlockingQueue;
40 import java.util.concurrent.TimeUnit;
41 
42 /**
43  * A helper class for binding to client implemented ResultHandlingService and trigger handleResult.
44  */
45 public class ResultCallbackHelper {
46     private static final String TAG = ResultCallbackHelper.class.getSimpleName();
47     private static final long RESULT_HANDLING_SERVICE_CALLBACK_TIMEOUT_SECS = 10;
48 
49     /** The outcome of the result handling. */
50     public enum CallbackResult {
51         // Result handling succeeded, and the task completed.
52         SUCCESS,
53         // Result handling failed.
54         FAIL,
55         // Result handling succeeded, but the task needs to resume.
56         NEEDS_RESUME,
57     }
58 
59     private final Context mContext;
60     private AbstractServiceBinder<IResultHandlingService> mResultHandlingServiceBinder;
61 
ResultCallbackHelper(Context context)62     public ResultCallbackHelper(Context context) {
63         this.mContext = context.getApplicationContext();
64     }
65 
66     /**
67      * Publishes the training result and example list to client implemented ResultHandlingService.
68      */
callHandleResult( String taskId, FederatedTrainingTask task, ComputationResult result)69     public ListenableFuture<CallbackResult> callHandleResult(
70             String taskId, FederatedTrainingTask task, ComputationResult result) {
71         Bundle input = new Bundle();
72         input.putString(ClientConstants.EXTRA_POPULATION_NAME, task.populationName());
73         input.putString(ClientConstants.EXTRA_TASK_ID, taskId);
74         input.putByteArray(ClientConstants.EXTRA_CONTEXT_DATA, task.contextData());
75         input.putInt(
76                 ClientConstants.EXTRA_COMPUTATION_RESULT,
77                 result.isResultSuccess() ? STATUS_SUCCESS : STATUS_TRAINING_FAILED);
78         input.putParcelableArrayList(
79                 ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, result.getExampleConsumptionList());
80 
81         try {
82             IResultHandlingService resultHandlingService =
83                     getResultHandlingService(task.appPackageName());
84             if (resultHandlingService == null) {
85                 LogUtil.e(
86                         TAG,
87                         "ResultHandlingService binding died. population name: "
88                                 + task.populationName());
89                 return Futures.immediateFuture(CallbackResult.FAIL);
90             }
91 
92             BlockingQueue<Integer> asyncResult = new ArrayBlockingQueue<>(1);
93             resultHandlingService.handleResult(
94                     input,
95                     new IFederatedComputeCallback.Stub() {
96                         @Override
97                         public void onSuccess() {
98                             asyncResult.add(STATUS_SUCCESS);
99                         }
100 
101                         @Override
102                         public void onFailure(int errorCode) {
103                             asyncResult.add(errorCode);
104                         }
105                     });
106             int statusCode =
107                     asyncResult.poll(
108                             RESULT_HANDLING_SERVICE_CALLBACK_TIMEOUT_SECS, TimeUnit.SECONDS);
109             CallbackResult callbackResult =
110                     statusCode == STATUS_SUCCESS ? CallbackResult.SUCCESS : CallbackResult.FAIL;
111             return Futures.immediateFuture(callbackResult);
112         } catch (Exception e) {
113             LogUtil.e(
114                     TAG,
115                     e,
116                     "ResultHandlingService binding died. population name: %s",
117                     task.populationName());
118             // We publish result to client app with best effort and should not crash flow.
119             return Futures.immediateFuture(CallbackResult.FAIL);
120         } finally {
121             unbindFromResultHandlingService();
122         }
123     }
124 
125     @VisibleForTesting
getResultHandlingService(String appPackageName)126     IResultHandlingService getResultHandlingService(String appPackageName) {
127         mResultHandlingServiceBinder =
128                 AbstractServiceBinder.getServiceBinderByIntent(
129                         this.mContext,
130                         RESULT_HANDLING_SERVICE_ACTION,
131                         appPackageName,
132                         IResultHandlingService.Stub::asInterface);
133         return mResultHandlingServiceBinder.getService(Runnable::run);
134     }
135 
136     @VisibleForTesting
unbindFromResultHandlingService()137     void unbindFromResultHandlingService() {
138         mResultHandlingServiceBinder.unbindFromService();
139     }
140 }
141