• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package android.federatedcompute;
18 
import static android.federatedcompute.common.ClientConstants.STATUS_INTERNAL_ERROR;
import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS;
import static android.federatedcompute.common.TrainingInterval.SCHEDULING_MODE_ONE_TIME;

import static com.google.common.truth.Truth.assertThat;

import static org.junit.Assert.assertTrue;

import android.federatedcompute.aidl.IFederatedComputeCallback;
import android.federatedcompute.aidl.IResultHandlingService;
import android.federatedcompute.common.ExampleConsumption;
import android.federatedcompute.common.TrainingInterval;
import android.federatedcompute.common.TrainingOptions;

import androidx.test.ext.junit.runners.AndroidJUnit4;

import com.google.common.collect.ImmutableList;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;

import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
44 
45 @RunWith(AndroidJUnit4.class)
46 public final class ResultHandlingServiceTest {
47     private static final String TEST_POPULATION = "testPopulation";
48     private static final int JOB_ID = 12345;
49     private static final byte[] SELECTION_CRITERIA = new byte[] {10, 0, 1};
50     private static final TrainingOptions TRAINING_OPTIONS =
51             new TrainingOptions.Builder()
52                     .setPopulationName(TEST_POPULATION)
53                     .setJobSchedulerJobId(JOB_ID)
54                     .setTrainingInterval(
55                             new TrainingInterval.Builder()
56                                     .setSchedulingMode(SCHEDULING_MODE_ONE_TIME)
57                                     .build())
58                     .build();
59     private static final ImmutableList<ExampleConsumption> EXAMPLE_CONSUMPTIONS =
60             ImmutableList.of(
61                     new ExampleConsumption.Builder()
62                             .setCollectionName("collection")
63                             .setExampleCount(100)
64                             .setSelectionCriteria(SELECTION_CRITERIA)
65                             .build());
66 
67     private boolean mSuccess = false;
68     private boolean mHandleResultCalled = false;
69     private int mErrorCode = 0;
70     private final CountDownLatch mLatch = new CountDownLatch(1);
71 
72     private IResultHandlingService mBinder;
73     private final TestResultHandlingService mTestResultHandlingService =
74             new TestResultHandlingService();
75 
76     @Before
doBeforeEachTest()77     public void doBeforeEachTest() {
78         mTestResultHandlingService.onCreate();
79         mBinder = IResultHandlingService.Stub.asInterface(mTestResultHandlingService.onBind(null));
80     }
81 
82     @Test
testHandleResult_success()83     public void testHandleResult_success() throws Exception {
84         mBinder.handleResult(
85                 TRAINING_OPTIONS, true, EXAMPLE_CONSUMPTIONS, new TestFederatedComputeCallback());
86 
87         mLatch.await();
88         assertTrue(mHandleResultCalled);
89         assertTrue(mSuccess);
90     }
91 
92     @Test
testHandleResult_failure()93     public void testHandleResult_failure() throws Exception {
94         mBinder.handleResult(TRAINING_OPTIONS, true, null, new TestFederatedComputeCallback());
95 
96         mLatch.await();
97         assertTrue(mHandleResultCalled);
98         assertThat(mErrorCode).isEqualTo(STATUS_INTERNAL_ERROR);
99     }
100 
101     class TestResultHandlingService extends ResultHandlingService {
102         @Override
handleResult( TrainingOptions trainingOptions, boolean success, List<ExampleConsumption> exampleConsumptionList, Consumer<Integer> callback)103         public void handleResult(
104                 TrainingOptions trainingOptions,
105                 boolean success,
106                 List<ExampleConsumption> exampleConsumptionList,
107                 Consumer<Integer> callback) {
108             mHandleResultCalled = true;
109             if (exampleConsumptionList == null || exampleConsumptionList.isEmpty()) {
110                 callback.accept(STATUS_INTERNAL_ERROR);
111                 return;
112             }
113             callback.accept(STATUS_SUCCESS);
114         }
115     }
116 
117     class TestFederatedComputeCallback extends IFederatedComputeCallback.Stub {
118         @Override
onSuccess()119         public void onSuccess() {
120             mSuccess = true;
121             mLatch.countDown();
122         }
123 
124         @Override
onFailure(int errorCode)125         public void onFailure(int errorCode) {
126             mErrorCode = errorCode;
127             mLatch.countDown();
128         }
129     }
130 }
131