/* * Copyright 2024 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package android.adservices.ondevicepersonalization; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import android.adservices.ondevicepersonalization.aidl.IDataAccessService; import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback; import android.adservices.ondevicepersonalization.aidl.IFederatedComputeCallback; import android.adservices.ondevicepersonalization.aidl.IFederatedComputeService; import android.adservices.ondevicepersonalization.aidl.IIsolatedModelService; import android.adservices.ondevicepersonalization.aidl.IIsolatedModelServiceCallback; import android.adservices.ondevicepersonalization.aidl.IIsolatedService; import android.adservices.ondevicepersonalization.aidl.IIsolatedServiceCallback; import android.content.Context; import android.federatedcompute.common.TrainingOptions; import android.net.Uri; import android.os.Bundle; import android.os.PersistableBundle; import android.os.RemoteException; import androidx.test.core.app.ApplicationProvider; import androidx.test.filters.SmallTest; import com.android.federatedcompute.internal.util.AbstractServiceBinder; import com.android.ondevicepersonalization.internal.util.ByteArrayParceledSlice; import com.android.ondevicepersonalization.internal.util.PersistableBundleUtils; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import java.util.Arrays; import java.util.Collection; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @SmallTest @RunWith(Parameterized.class) public class IsolatedServiceExceptionSafetyTest { private final Context mContext = ApplicationProvider.getApplicationContext(); private IIsolatedService mIsolatedService; private static AbstractServiceBinder sServiceBinder; private int mCallbackErrorCode; private int mIsolatedServiceErrorCode; private byte[] mSerializedExceptionInfo; private CountDownLatch mLatch; @Parameterized.Parameter(0) public String operation; @Parameterized.Parameters public static Collection data() { return Arrays.asList( new Object[][] { {RuntimeException.class.getName()}, {NullPointerException.class.getName()}, {IllegalArgumentException.class.getName()} }); } @Before public void setUp() throws Exception { if (sServiceBinder == null) { sServiceBinder = AbstractServiceBinder.getIsolatedServiceBinderByServiceName( mContext, "android.adservices.ondevicepersonalization." + "IsolatedServiceExceptionSafetyTestImpl", mContext.getPackageName(), "testIsolatedProcess", 0, IIsolatedService.Stub::asInterface); } mIsolatedService = sServiceBinder.getService(Runnable::run); mLatch = new CountDownLatch(1); } @After public void tearDown() { mIsolatedService = null; mCallbackErrorCode = 0; } @AfterClass public static void tearDownClass() { sServiceBinder.unbindFromService(); } @Test public void testOnRequestExceptions() throws Exception { PersistableBundle appParams = new PersistableBundle(); appParams.putString("ex", operation); ExecuteInputParcel input = new ExecuteInputParcel.Builder() .setAppPackageName("com.testapp") .setSerializedAppParams(new ByteArrayParceledSlice( PersistableBundleUtils.toByteArray(appParams))) .build(); Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_INPUT, input); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); params.putBinder( Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER, new TestFederatedComputeService()); params.putBinder(Constants.EXTRA_MODEL_SERVICE_BINDER, new TestIsolatedModelService()); mIsolatedService.onRequest(Constants.OP_EXECUTE, params, new TestServiceCallback()); assertTrue(mLatch.await(5000, TimeUnit.MILLISECONDS)); assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode); } @Test public void testOnDownloadExceptions() throws Exception { DownloadInputParcel input = new DownloadInputParcel.Builder() .setDataAccessServiceBinder(new TestDataAccessService(operation)) .build(); Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_INPUT, input); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); params.putBinder( Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER, new TestFederatedComputeService()); mIsolatedService.onRequest(Constants.OP_DOWNLOAD, params, new TestServiceCallback()); assertTrue(mLatch.await(5000, TimeUnit.MILLISECONDS)); assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode); } @Test public void testOnRender() throws Exception { RenderInputParcel input = new RenderInputParcel.Builder() .setRenderingConfig( new RenderingConfig.Builder().addKey(operation).build()) .build(); Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_INPUT, input); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); mIsolatedService.onRequest(Constants.OP_RENDER, params, new TestServiceCallback()); assertTrue(mLatch.await(5000, TimeUnit.MILLISECONDS)); assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode); } @Test public void testOnEvent() throws Exception { PersistableBundle appParams = new PersistableBundle(); appParams.putString("ex", operation); Bundle params = new Bundle(); params.putParcelable( Constants.EXTRA_INPUT, new EventInputParcel.Builder().setParameters(appParams).build()); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); params.putBinder(Constants.EXTRA_MODEL_SERVICE_BINDER, new TestIsolatedModelService()); mIsolatedService.onRequest(Constants.OP_WEB_VIEW_EVENT, params, new TestServiceCallback()); assertTrue(mLatch.await(5000, TimeUnit.MILLISECONDS)); assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode); } @Test public void testOnTrainingExamples() throws Exception { TrainingExamplesInputParcel input = new TrainingExamplesInputParcel.Builder() .setPopulationName("") .setTaskName(operation) .setResumptionToken(new byte[] {0}) .build(); Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_INPUT, input); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); mIsolatedService.onRequest( Constants.OP_TRAINING_EXAMPLE, params, new TestServiceCallback()); assertTrue(mLatch.await(5000, TimeUnit.MILLISECONDS)); assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode); } @Test public void testOnWebTrigger() throws Exception { WebTriggerInputParcel input = new WebTriggerInputParcel.Builder( Uri.parse("http://desturl"), operation, new byte[] {1, 2, 3}) .build(); Bundle params = new Bundle(); params.putParcelable(Constants.EXTRA_INPUT, input); params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService()); params.putBinder(Constants.EXTRA_MODEL_SERVICE_BINDER, new TestIsolatedModelService()); mIsolatedService.onRequest(Constants.OP_WEB_TRIGGER, params, new TestServiceCallback()); assertTrue(mLatch.await(5000, TimeUnit.MILLISECONDS)); assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode); } class TestServiceCallback extends IIsolatedServiceCallback.Stub { @Override public void onSuccess(Bundle result) { mLatch.countDown(); } @Override public void onError( int errorCode, int isolatedServiceErrorCode, byte[] serializedExceptionInfo) { mCallbackErrorCode = errorCode; mIsolatedServiceErrorCode = isolatedServiceErrorCode; mSerializedExceptionInfo = serializedExceptionInfo; mLatch.countDown(); } } static class TestDataAccessService extends IDataAccessService.Stub { String mOp; TestDataAccessService(String operation) { this.mOp = operation; } TestDataAccessService() { mOp = null; } @Override public void onRequest(int operation, Bundle params, IDataAccessServiceCallback callback) { // pass parameters for onDownloadCompleted testing if (mOp != null) { Bundle bndl = new Bundle(); bndl.putParcelable( Constants.EXTRA_RESULT, new ByteArrayParceledSlice(mOp.getBytes())); try { callback.onSuccess(bndl); } catch (RemoteException e) { throw new RuntimeException(e); } } } @Override public void logApiCallStats(int apiName, long latencyMillis, int responseCode) {} } static class TestFederatedComputeService extends IFederatedComputeService.Stub { @Override public void schedule(TrainingOptions trainingOptions, IFederatedComputeCallback callback) {} public void cancel(String populationName, IFederatedComputeCallback callback) {} } static class TestIsolatedModelService extends IIsolatedModelService.Stub { @Override public void runInference(Bundle params, IIsolatedModelServiceCallback callback) {} } }