1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.federatedcompute.services.training; 18 19 import static android.federatedcompute.common.ClientConstants.RESULT_HANDLING_SERVICE_ACTION; 20 import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS; 21 import static android.federatedcompute.common.ClientConstants.STATUS_TRAINING_FAILED; 22 23 import android.content.Context; 24 import android.federatedcompute.aidl.IFederatedComputeCallback; 25 import android.federatedcompute.aidl.IResultHandlingService; 26 import android.federatedcompute.common.ClientConstants; 27 import android.os.Bundle; 28 29 import com.android.federatedcompute.internal.util.AbstractServiceBinder; 30 import com.android.federatedcompute.internal.util.LogUtil; 31 import com.android.federatedcompute.services.data.FederatedTrainingTask; 32 import com.android.federatedcompute.services.training.util.ComputationResult; 33 34 import com.google.common.annotations.VisibleForTesting; 35 import com.google.common.util.concurrent.Futures; 36 import com.google.common.util.concurrent.ListenableFuture; 37 38 import java.util.concurrent.ArrayBlockingQueue; 39 import java.util.concurrent.BlockingQueue; 40 import java.util.concurrent.TimeUnit; 41 42 /** 43 * A helper class for binding to client implemented ResultHandlingService and trigger handleResult. 44 */ 45 public class ResultCallbackHelper { 46 private static final String TAG = ResultCallbackHelper.class.getSimpleName(); 47 private static final long RESULT_HANDLING_SERVICE_CALLBACK_TIMEOUT_SECS = 10; 48 49 /** The outcome of the result handling. */ 50 public enum CallbackResult { 51 // Result handling succeeded, and the task completed. 52 SUCCESS, 53 // Result handling failed. 54 FAIL, 55 // Result handling succeeded, but the task needs to resume. 56 NEEDS_RESUME, 57 } 58 59 private final Context mContext; 60 private AbstractServiceBinder<IResultHandlingService> mResultHandlingServiceBinder; 61 ResultCallbackHelper(Context context)62 public ResultCallbackHelper(Context context) { 63 this.mContext = context.getApplicationContext(); 64 } 65 66 /** 67 * Publishes the training result and example list to client implemented ResultHandlingService. 68 */ callHandleResult( String taskId, FederatedTrainingTask task, ComputationResult result)69 public ListenableFuture<CallbackResult> callHandleResult( 70 String taskId, FederatedTrainingTask task, ComputationResult result) { 71 Bundle input = new Bundle(); 72 input.putString(ClientConstants.EXTRA_POPULATION_NAME, task.populationName()); 73 input.putString(ClientConstants.EXTRA_TASK_ID, taskId); 74 input.putByteArray(ClientConstants.EXTRA_CONTEXT_DATA, task.contextData()); 75 input.putInt( 76 ClientConstants.EXTRA_COMPUTATION_RESULT, 77 result.isResultSuccess() ? STATUS_SUCCESS : STATUS_TRAINING_FAILED); 78 input.putParcelableArrayList( 79 ClientConstants.EXTRA_EXAMPLE_CONSUMPTION_LIST, result.getExampleConsumptionList()); 80 81 try { 82 IResultHandlingService resultHandlingService = 83 getResultHandlingService(task.appPackageName()); 84 if (resultHandlingService == null) { 85 LogUtil.e( 86 TAG, 87 "ResultHandlingService binding died. population name: " 88 + task.populationName()); 89 return Futures.immediateFuture(CallbackResult.FAIL); 90 } 91 92 BlockingQueue<Integer> asyncResult = new ArrayBlockingQueue<>(1); 93 resultHandlingService.handleResult( 94 input, 95 new IFederatedComputeCallback.Stub() { 96 @Override 97 public void onSuccess() { 98 asyncResult.add(STATUS_SUCCESS); 99 } 100 101 @Override 102 public void onFailure(int errorCode) { 103 asyncResult.add(errorCode); 104 } 105 }); 106 int statusCode = 107 asyncResult.poll( 108 RESULT_HANDLING_SERVICE_CALLBACK_TIMEOUT_SECS, TimeUnit.SECONDS); 109 CallbackResult callbackResult = 110 statusCode == STATUS_SUCCESS ? CallbackResult.SUCCESS : CallbackResult.FAIL; 111 return Futures.immediateFuture(callbackResult); 112 } catch (Exception e) { 113 LogUtil.e( 114 TAG, 115 e, 116 "ResultHandlingService binding died. population name: %s", 117 task.populationName()); 118 // We publish result to client app with best effort and should not crash flow. 119 return Futures.immediateFuture(CallbackResult.FAIL); 120 } finally { 121 unbindFromResultHandlingService(); 122 } 123 } 124 125 @VisibleForTesting getResultHandlingService(String appPackageName)126 IResultHandlingService getResultHandlingService(String appPackageName) { 127 mResultHandlingServiceBinder = 128 AbstractServiceBinder.getServiceBinderByIntent( 129 this.mContext, 130 RESULT_HANDLING_SERVICE_ACTION, 131 appPackageName, 132 IResultHandlingService.Stub::asInterface); 133 return mResultHandlingServiceBinder.getService(Runnable::run); 134 } 135 136 @VisibleForTesting unbindFromResultHandlingService()137 void unbindFromResultHandlingService() { 138 mResultHandlingServiceBinder.unbindFromService(); 139 } 140 } 141