/*
 * Copyright 2020 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.binder;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;

import android.content.Context;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.common.base.Function;
import com.google.protobuf.Empty;
import io.grpc.CallOptions;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.ServerServiceDefinition;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.protobuf.lite.ProtoLiteUtils;
import io.grpc.stub.ClientCalls;
import io.grpc.stub.ServerCalls;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;

/**
 * Exercises the security policies of the binder transport.
 *
 * <p>Each test hosts three services ("foo", "bar" and "baz"), each exposing two unary echo
 * methods, behind a server built with a given {@link ServerSecurityPolicy}, and talks to it over a
 * channel built with a given {@link SecurityPolicy}.
 */
@RunWith(AndroidJUnit4.class)
public final class BinderSecurityTest {
  private final Context appContext = ApplicationProvider.getApplicationContext();

  /** Names of the services registered on the test server. */
  String[] serviceNames = new String[] {"foo", "bar", "baz"};

  /** One definition per entry in {@link #serviceNames}; populated before each test. */
  List<ServerServiceDefinition> serviceDefinitions = new ArrayList<>();

  /** The channel under test; stays null until one of the createChannel overloads is called. */
  @Nullable ManagedChannel channel;

  /** Every registered method, keyed by its full method name. */
  Map<String, MethodDescriptor<Empty, Empty>> methods = new HashMap<>();

  /** Methods whose server-side handler has been invoked. */
  List<MethodDescriptor<Empty, Empty>> calls = new ArrayList<>();

  /** User-level interceptor installed on the server; counts the calls it sees. */
  CountingServerInterceptor countingServerInterceptor;

  /**
   * Builds the service definitions and method descriptors shared by all tests, and creates a fresh
   * counting interceptor.
   */
  @Before
  public void setupServiceDefinitionsAndMethods() {
    MethodDescriptor.Marshaller<Empty> marshaller =
        ProtoLiteUtils.marshaller(Empty.getDefaultInstance());
    for (String serviceName : serviceNames) {
      ServerServiceDefinition.Builder builder = ServerServiceDefinition.builder(serviceName);
      for (int i = 0; i < 2; i++) {
        // Add two methods to the service.
        String name = serviceName + "/method" + i;
        MethodDescriptor<Empty, Empty> method =
            MethodDescriptor.newBuilder(marshaller, marshaller)
                .setFullMethodName(name)
                .setType(MethodDescriptor.MethodType.UNARY)
                .build();
        // The handler records the call and echoes the request back as the response.
        ServerCallHandler<Empty, Empty> callHandler =
            ServerCalls.asyncUnaryCall(
                (req, respObserver) -> {
                  calls.add(method);
                  respObserver.onNext(req);
                  respObserver.onCompleted();
                });
        builder.addMethod(method, callHandler);
        methods.put(name, method);
      }
      serviceDefinitions.add(builder.build());
    }
    countingServerInterceptor = new CountingServerInterceptor();
  }

  /** Shuts down the channel, if one was created, and waits for the hosted service to stop. */
  @After
  public void tearDown() throws Exception {
    if (channel != null) {
      channel.shutdownNow();
    }
    HostServices.awaitServiceShutdown();
  }

  /** Creates the channel using the "internal only" policy on both the server and the client. */
  private void createChannel() throws Exception {
    createChannel(SecurityPolicies.serverInternalOnly(), SecurityPolicies.internalOnly());
  }

  /**
   * Allocates a host service, configures it to build a server that uses {@code serverPolicy}, and
   * stores a channel to that service that uses {@code channelPolicy} in {@link #channel}.
   *
   * @param serverPolicy the policy installed on the server
   * @param channelPolicy the policy installed on the client channel
   */
  private void createChannel(ServerSecurityPolicy serverPolicy, SecurityPolicy channelPolicy)
      throws Exception {
    AndroidComponentAddress addr = HostServices.allocateService(appContext);
    HostServices.configureService(
        addr,
        HostServices.serviceParamsBuilder()
            .setServerFactory((service, receiver) -> buildServer(addr, receiver, serverPolicy))
            .build());

    channel =
        BinderChannelBuilder.forAddress(addr, appContext)
            .securityPolicy(channelPolicy)
            .build();
  }

  /**
   * Builds a binder server that serves every definition in {@link #serviceDefinitions}, applies
   * {@code serverPolicy}, and has {@link #countingServerInterceptor} installed.
   *
   * @param listenAddr the address the server listens on
   * @param receiver the receiver handed to the server builder
   * @param serverPolicy the security policy the server enforces
   * @return the built server
   */
  private Server buildServer(
      AndroidComponentAddress listenAddr,
      IBinderReceiver receiver,
      ServerSecurityPolicy serverPolicy) {
    BinderServerBuilder serverBuilder = BinderServerBuilder.forAddress(listenAddr, receiver);
    serverBuilder.securityPolicy(serverPolicy);
    serverBuilder.intercept(countingServerInterceptor);

    for (ServerServiceDefinition serviceDefinition : serviceDefinitions) {
      serverBuilder.addService(serviceDefinition);
    }
    return serverBuilder.build();
  }

  /** Asserts that a blocking unary call to {@code method} returns a non-null response. */
  private void assertCallSuccess(MethodDescriptor<Empty, Empty> method) {
    assertThat(
            ClientCalls.blockingUnaryCall(
                channel, method, CallOptions.DEFAULT, Empty.getDefaultInstance()))
        .isNotNull();
  }

  /**
   * Asserts that a blocking unary call to {@code method} fails with the status code of {@code
   * status}.
   *
   * @param method the method to invoke
   * @param status the status whose code the failure is expected to carry
   */
  private void assertCallFailure(MethodDescriptor<Empty, Empty> method, Status status) {
    try {
      // Send a real request (not null) so that the only reason the call can fail is the policy
      // under test, not a client-side problem with the message itself.
      ClientCalls.blockingUnaryCall(
          channel, method, CallOptions.DEFAULT, Empty.getDefaultInstance());
      fail("Expected call to " + method.getFullMethodName() + " to fail but it succeeded.");
    } catch (StatusRuntimeException sre) {
      assertThat(sre.getStatus().getCode()).isEqualTo(status.getCode());
    }
  }

  /** With the default policies on both sides, every method can be called. */
  @Test
  public void testAllowedCall() throws Exception {
    createChannel();
    assertThat(methods).isNotEmpty();
    for (MethodDescriptor<Empty, Empty> method : methods.values()) {
      assertCallSuccess(method);
    }
  }

  /** When the server policy rejects every service, every call fails with PERMISSION_DENIED. */
  @Test
  public void testServerDisllowsCalls() throws Exception {
    createChannel(
        ServerSecurityPolicy.newBuilder()
            .servicePolicy("foo", policy((uid) -> false))
            .servicePolicy("bar", policy((uid) -> false))
            .servicePolicy("baz", policy((uid) -> false))
            .build(),
        SecurityPolicies.internalOnly());
    assertThat(methods).isNotEmpty();
    for (MethodDescriptor<Empty, Empty> method : methods.values()) {
      assertCallFailure(method, Status.PERMISSION_DENIED);
    }
  }

  /** When the channel policy rejects the server, every call fails with PERMISSION_DENIED. */
  @Test
  public void testClientDoesntTrustServer() throws Exception {
    createChannel(SecurityPolicies.serverInternalOnly(), policy((uid) -> false));
    assertThat(methods).isNotEmpty();
    for (MethodDescriptor<Empty, Empty> method : methods.values()) {
      assertCallFailure(method, Status.PERMISSION_DENIED);
    }
  }

  /**
   * Policies are applied per service: "bar" is explicitly denied and must fail, while "foo"
   * (explicitly allowed) and "baz" (no explicit policy) are expected to succeed.
   */
  @Test
  public void testPerServicePolicy() throws Exception {
    createChannel(
        ServerSecurityPolicy.newBuilder()
            .servicePolicy("foo", policy((uid) -> true))
            .servicePolicy("bar", policy((uid) -> false))
            .build(),
        SecurityPolicies.internalOnly());

    assertThat(methods).isNotEmpty();
    for (MethodDescriptor<Empty, Empty> method : methods.values()) {
      if (method.getServiceName().equals("bar")) {
        assertCallFailure(method, Status.PERMISSION_DENIED);
      } else {
        assertCallSuccess(method);
      }
    }
  }

  /**
   * Calls rejected by the security policy must never reach user-installed interceptors: with only
   * "foo" allowed, the counting interceptor should see exactly foo's two methods.
   */
  @Test
  public void testSecurityInterceptorIsClosestToTransport() throws Exception {
    createChannel(
        ServerSecurityPolicy.newBuilder()
            .servicePolicy("foo", policy((uid) -> true))
            .servicePolicy("bar", policy((uid) -> false))
            .servicePolicy("baz", policy((uid) -> false))
            .build(),
        SecurityPolicies.internalOnly());
    assertThat(countingServerInterceptor.numInterceptedCalls).isEqualTo(0);
    for (MethodDescriptor<Empty, Empty> method : methods.values()) {
      try {
        ClientCalls.blockingUnaryCall(
            channel, method, CallOptions.DEFAULT, Empty.getDefaultInstance());
      } catch (StatusRuntimeException ignored) {
        // Calls to the denied services are expected to fail; this test only cares about how many
        // calls reached the interceptor.
      }
    }
    // Only the foo calls should have made it to the user interceptor.
    assertThat(countingServerInterceptor.numInterceptedCalls).isEqualTo(2);
  }

  /**
   * Adapts a predicate over the caller's uid into a {@link SecurityPolicy}.
   *
   * @param func decides, given a uid, whether the caller is authorized
   * @return a policy yielding {@link Status#OK} when {@code func} returns true for the uid and
   *     {@link Status#PERMISSION_DENIED} otherwise
   */
  private static SecurityPolicy policy(Function<Integer, Boolean> func) {
    return new SecurityPolicy() {
      @Override
      public Status checkAuthorization(int uid) {
        return func.apply(uid) ? Status.OK : Status.PERMISSION_DENIED;
      }
    };
  }

  /** Server interceptor that counts the calls passing through it, then delegates unchanged. */
  private static final class CountingServerInterceptor implements ServerInterceptor {
    int numInterceptedCalls;

    @Override
    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
        ServerCall<ReqT, RespT> call,
        Metadata headers,
        ServerCallHandler<ReqT, RespT> next) {
      numInterceptedCalls += 1;
      return next.startCall(call, headers);
    }
  }
}