1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 package android.federatedcompute; 17 18 import static org.junit.Assert.assertThrows; 19 import static org.mockito.ArgumentMatchers.any; 20 import static org.mockito.ArgumentMatchers.anyInt; 21 import static org.mockito.ArgumentMatchers.isNull; 22 import static org.mockito.Mockito.doAnswer; 23 import static org.mockito.Mockito.doThrow; 24 import static org.mockito.Mockito.spy; 25 import static org.mockito.Mockito.times; 26 import static org.mockito.Mockito.verify; 27 import static org.mockito.Mockito.when; 28 29 import android.content.ComponentName; 30 import android.content.Context; 31 import android.content.ContextWrapper; 32 import android.content.Intent; 33 import android.content.ServiceConnection; 34 import android.content.pm.PackageManager; 35 import android.content.pm.ResolveInfo; 36 import android.content.pm.ServiceInfo; 37 import android.federatedcompute.aidl.IFederatedComputeCallback; 38 import android.federatedcompute.aidl.IFederatedComputeService; 39 import android.federatedcompute.common.ScheduleFederatedComputeRequest; 40 import android.federatedcompute.common.TrainingOptions; 41 import android.os.IBinder; 42 import android.os.OutcomeReceiver; 43 import android.os.RemoteException; 44 45 import androidx.test.core.app.ApplicationProvider; 46 47 import org.junit.Before; 48 import org.junit.Test; 49 import org.junit.runner.RunWith; 50 import org.junit.runners.Parameterized; 51 import org.mockito.Mock; 52 import org.mockito.MockitoAnnotations; 53 54 import java.util.Arrays; 55 import java.util.Collection; 56 import java.util.List; 57 import java.util.concurrent.Executor; 58 import java.util.concurrent.Executors; 59 60 @RunWith(Parameterized.class) 61 public class FederatedComputeManagerTest { 62 63 private final Context mContext = 64 spy(new MyTestContext(ApplicationProvider.getApplicationContext())); 65 66 private static final ComponentName OWNER_COMPONENT = 67 ComponentName.createRelative("com.android.package.name", "com.android.class.name"); 68 69 @Parameterized.Parameter(0) 70 public String scenario; 71 72 @Parameterized.Parameter(1) 73 public ScheduleFederatedComputeRequest request; 74 75 @Parameterized.Parameter(2) 76 public String populationName; 77 78 @Parameterized.Parameter(3) 79 public IFederatedComputeService iFederatedComputeService; 80 81 @Mock private PackageManager mMockPackageManager; 82 @Mock private IBinder mMockIBinder; 83 @Mock private IFederatedComputeService mMockIService; 84 85 @Parameterized.Parameters data()86 public static Collection<Object[]> data() { 87 return Arrays.asList( 88 new Object[][] { 89 {"schedule-allNull", null, null, null}, 90 { 91 "schedule-default-iService", 92 new ScheduleFederatedComputeRequest.Builder() 93 .setTrainingOptions(new TrainingOptions.Builder().build()) 94 .build(), 95 null, 96 new IFederatedComputeService.Default() 97 }, 98 { 99 "schedule-mockIService-RemoteException", 100 new ScheduleFederatedComputeRequest.Builder() 101 .setTrainingOptions(new TrainingOptions.Builder().build()) 102 .build(), 103 null, 104 null /* mock will be returned */ 105 }, 106 { 107 "schedule-mockIService-onSuccess", 108 new ScheduleFederatedComputeRequest.Builder() 109 .setTrainingOptions(new TrainingOptions.Builder().build()) 110 .build(), 111 null, 112 null /* mock will be returned */ 113 }, 114 { 115 "schedule-mockIService-onFailure", 116 new ScheduleFederatedComputeRequest.Builder() 117 .setTrainingOptions(new TrainingOptions.Builder().build()) 118 .build(), 119 null, 120 null /* mock will be returned */ 121 }, 122 { 123 "schedule-unavailable-iService", 124 new ScheduleFederatedComputeRequest.Builder() 125 .setTrainingOptions(new TrainingOptions.Builder().build()) 126 .build(), 127 null, 128 null /* throw exception when getting instance */ 129 }, 130 {"cancel-allNull", null, null, null}, 131 { 132 "cancel-default-iService", 133 null, 134 "testPopulation", 135 new IFederatedComputeService.Default() 136 }, 137 { 138 "cancel-mockIService-RemoteException", 139 null, 140 "testPopulation", 141 null /* mock will be returned */ 142 }, 143 { 144 "cancel-mockIService-onSuccess", 145 null, 146 "testPopulation", 147 null /* mock will be returned */ 148 }, 149 { 150 "cancel-mockIService-onFailure", 151 null, 152 "testPopulation", 153 null /* mock will be returned */ 154 }, 155 { 156 "cancel-unavailable-iService", 157 null, 158 "testPopulation", 159 null /* throw exception when getting instance */ 160 }, 161 }); 162 } 163 164 @Before setUp()165 public void setUp() { 166 MockitoAnnotations.initMocks(this); 167 ResolveInfo resolveInfo = new ResolveInfo(); 168 ServiceInfo serviceInfo = new ServiceInfo(); 169 serviceInfo.name = "TestName"; 170 serviceInfo.packageName = "com.android.federatedcompute.services"; 171 resolveInfo.serviceInfo = serviceInfo; 172 when(mMockPackageManager.queryIntentServices(any(), anyInt())) 173 .thenReturn(List.of(resolveInfo)); 174 when(mMockIBinder.queryLocalInterface(any())).thenReturn(iFederatedComputeService); 175 } 176 177 @Test testScheduleFederatedCompute()178 public void testScheduleFederatedCompute() throws RemoteException { 179 FederatedComputeManager manager = new FederatedComputeManager(mContext); 180 OutcomeReceiver<Object, Exception> spyCallback; 181 182 switch (scenario) { 183 case "schedule-allNull": 184 assertThrows( 185 NullPointerException.class, () -> manager.schedule(request, null, null)); 186 break; 187 case "schedule-default-iService": 188 manager.schedule(request, Executors.newSingleThreadExecutor(), null); 189 break; 190 case "schedule-mockIService-RemoteException": 191 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); 192 doThrow(new RemoteException()).when(mMockIService).schedule(any(), any(), any()); 193 spyCallback = spy(new MyTestCallback()); 194 195 manager.schedule(request, Runnable::run, spyCallback); 196 197 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); 198 verify(spyCallback, times(1)).onError(any(RemoteException.class)); 199 verify(mContext, times(1)).unbindService(any()); 200 break; 201 case "schedule-mockIService-onSuccess": 202 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); 203 doAnswer( 204 invocation -> { 205 IFederatedComputeCallback federatedComputeCallback = 206 invocation.getArgument(2); 207 federatedComputeCallback.onSuccess(); 208 return null; 209 }) 210 .when(mMockIService) 211 .schedule(any(), any(), any()); 212 spyCallback = spy(new MyTestCallback()); 213 214 manager.schedule(request, Runnable::run, spyCallback); 215 216 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); 217 verify(spyCallback, times(1)).onResult(isNull()); 218 verify(mContext, times(1)).unbindService(any()); 219 break; 220 case "schedule-mockIService-onFailure": 221 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); 222 doAnswer( 223 invocation -> { 224 IFederatedComputeCallback federatedComputeCallback = 225 invocation.getArgument(2); 226 federatedComputeCallback.onFailure(1); 227 return null; 228 }) 229 .when(mMockIService) 230 .schedule(any(), any(), any()); 231 spyCallback = spy(new MyTestCallback()); 232 233 manager.schedule(request, Runnable::run, spyCallback); 234 235 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); 236 verify(spyCallback, times(1)).onError(any(FederatedComputeException.class)); 237 verify(mContext, times(1)).unbindService(any()); 238 break; 239 case "schedule-unavailable-iService": 240 when(mMockIBinder.queryLocalInterface(any())).thenThrow(RuntimeException.class); 241 spyCallback = spy(new MyTestCallback()); 242 243 manager.schedule(request, Runnable::run, spyCallback); 244 245 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); 246 verify(spyCallback, times(1)).onError(any(RuntimeException.class)); 247 verify(mContext, times(1)).unbindService(any()); 248 break; 249 case "cancel-allNull": 250 assertThrows( 251 NullPointerException.class, 252 () -> 253 manager.cancel( 254 OWNER_COMPONENT, 255 populationName, 256 null, 257 null)); 258 break; 259 case "cancel-default-iService": 260 manager.cancel( 261 OWNER_COMPONENT, 262 populationName, 263 Executors.newSingleThreadExecutor(), 264 null); 265 break; 266 case "cancel-mockIService-RemoteException": 267 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); 268 doThrow(new RemoteException()) 269 .when(mMockIService) 270 .cancel(any(), any(), any()); 271 spyCallback = spy(new MyTestCallback()); 272 273 manager.cancel( 274 OWNER_COMPONENT, 275 populationName, 276 Runnable::run, 277 spyCallback); 278 279 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); 280 verify(spyCallback, times(1)).onError(any(RemoteException.class)); 281 verify(mContext, times(1)).unbindService(any()); 282 break; 283 case "cancel-mockIService-onSuccess": 284 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); 285 doAnswer( 286 invocation -> { 287 IFederatedComputeCallback federatedComputeCallback = 288 invocation.getArgument(2); 289 federatedComputeCallback.onSuccess(); 290 return null; 291 }) 292 .when(mMockIService) 293 .cancel(any(), any(), any()); 294 spyCallback = spy(new MyTestCallback()); 295 296 manager.cancel( 297 OWNER_COMPONENT, 298 populationName, 299 Runnable::run, 300 spyCallback); 301 302 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); 303 verify(spyCallback, times(1)).onResult(isNull()); 304 verify(mContext, times(1)).unbindService(any()); 305 break; 306 case "cancel-mockIService-onFailure": 307 when(mMockIBinder.queryLocalInterface(any())).thenReturn(mMockIService); 308 doAnswer( 309 invocation -> { 310 IFederatedComputeCallback federatedComputeCallback = 311 invocation.getArgument(2); 312 federatedComputeCallback.onFailure(1); 313 return null; 314 }) 315 .when(mMockIService) 316 .cancel(any(), any(), any()); 317 spyCallback = spy(new MyTestCallback()); 318 319 manager.cancel( 320 OWNER_COMPONENT, 321 populationName, 322 Runnable::run, 323 spyCallback); 324 325 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); 326 verify(spyCallback, times(1)).onError(any(FederatedComputeException.class)); 327 verify(mContext, times(1)).unbindService(any()); 328 break; 329 case "cancel-unavailable-iService": 330 when(mMockIBinder.queryLocalInterface(any())).thenThrow(RuntimeException.class); 331 spyCallback = spy(new MyTestCallback()); 332 333 manager.cancel( 334 OWNER_COMPONENT, 335 populationName, 336 Runnable::run, 337 spyCallback); 338 339 verify(mContext, times(1)).bindService(any(), anyInt(), any(), any()); 340 verify(spyCallback, times(1)).onError(any(RuntimeException.class)); 341 verify(mContext, times(1)).unbindService(any()); 342 break; 343 default: 344 break; 345 } 346 } 347 348 public class MyTestContext extends ContextWrapper { 349 MyTestContext(Context context)350 MyTestContext(Context context) { 351 super(context); 352 } 353 354 @Override getPackageManager()355 public PackageManager getPackageManager() { 356 return mMockPackageManager != null ? mMockPackageManager : super.getPackageManager(); 357 } 358 359 @Override bindService( Intent service, int flags, Executor executor, ServiceConnection conn)360 public boolean bindService( 361 Intent service, int flags, Executor executor, ServiceConnection conn) { 362 executor.execute( 363 () -> { 364 conn.onServiceConnected(null, mMockIBinder); 365 }); 366 return true; 367 } 368 unbindService(ServiceConnection conn)369 public void unbindService(ServiceConnection conn) {} 370 } 371 372 public class MyTestCallback implements OutcomeReceiver<Object, Exception> { 373 374 @Override onResult(Object o)375 public void onResult(Object o) {} 376 377 @Override onError(Exception error)378 public void onError(Exception error) { 379 OutcomeReceiver.super.onError(error); 380 } 381 } 382 } 383