• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2020 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.binder;
18 
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;

import android.content.Context;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.common.base.Function;
import com.google.protobuf.Empty;
import io.grpc.CallOptions;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.ServerServiceDefinition;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.protobuf.lite.ProtoLiteUtils;
import io.grpc.stub.ClientCalls;
import io.grpc.stub.ServerCalls;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nullable;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
50 
51 @RunWith(AndroidJUnit4.class)
52 public final class BinderSecurityTest {
53   private final Context appContext = ApplicationProvider.getApplicationContext();
54 
55   String[] serviceNames = new String[] {"foo", "bar", "baz"};
56   List<ServerServiceDefinition> serviceDefinitions = new ArrayList<>();
57 
58   @Nullable ManagedChannel channel;
59   Map<String, MethodDescriptor<Empty, Empty>> methods = new HashMap<>();
60   List<MethodDescriptor<Empty, Empty>> calls = new ArrayList<>();
61   CountingServerInterceptor countingServerInterceptor;
62 
63   @Before
setupServiceDefinitionsAndMethods()64   public void setupServiceDefinitionsAndMethods() {
65     MethodDescriptor.Marshaller<Empty> marshaller =
66         ProtoLiteUtils.marshaller(Empty.getDefaultInstance());
67     for (String serviceName : serviceNames) {
68       ServerServiceDefinition.Builder builder = ServerServiceDefinition.builder(serviceName);
69       for (int i = 0; i < 2; i++) {
70         // Add two methods to the service.
71         String name = serviceName + "/method" + i;
72         MethodDescriptor<Empty, Empty> method =
73             MethodDescriptor.newBuilder(marshaller, marshaller)
74                 .setFullMethodName(name)
75                 .setType(MethodDescriptor.MethodType.UNARY)
76                 .build();
77         ServerCallHandler<Empty, Empty> callHandler =
78             ServerCalls.asyncUnaryCall(
79                 (req, respObserver) -> {
80                   calls.add(method);
81                   respObserver.onNext(req);
82                   respObserver.onCompleted();
83                 });
84         builder.addMethod(method, callHandler);
85         methods.put(name, method);
86       }
87       serviceDefinitions.add(builder.build());
88     }
89     countingServerInterceptor = new CountingServerInterceptor();
90   }
91 
92   @After
tearDown()93   public void tearDown() throws Exception {
94     if (channel != null) {
95       channel.shutdownNow();
96     }
97     HostServices.awaitServiceShutdown();
98   }
99 
createChannel()100   private void createChannel() throws Exception {
101     createChannel(SecurityPolicies.serverInternalOnly(), SecurityPolicies.internalOnly());
102   }
103 
createChannel(ServerSecurityPolicy serverPolicy, SecurityPolicy channelPolicy)104   private void createChannel(ServerSecurityPolicy serverPolicy, SecurityPolicy channelPolicy)
105       throws Exception {
106     AndroidComponentAddress addr = HostServices.allocateService(appContext);
107     HostServices.configureService(addr,
108         HostServices.serviceParamsBuilder()
109           .setServerFactory((service, receiver) -> buildServer(addr, receiver, serverPolicy))
110           .build());
111 
112     channel =
113         BinderChannelBuilder.forAddress(addr, appContext)
114             .securityPolicy(channelPolicy)
115             .build();
116   }
117 
buildServer( AndroidComponentAddress listenAddr, IBinderReceiver receiver, ServerSecurityPolicy serverPolicy)118   private Server buildServer(
119       AndroidComponentAddress listenAddr,
120       IBinderReceiver receiver,
121       ServerSecurityPolicy serverPolicy) {
122     BinderServerBuilder serverBuilder = BinderServerBuilder.forAddress(listenAddr, receiver);
123     serverBuilder.securityPolicy(serverPolicy);
124     serverBuilder.intercept(countingServerInterceptor);
125 
126     for (ServerServiceDefinition serviceDefinition : serviceDefinitions) {
127       serverBuilder.addService(serviceDefinition);
128     }
129     return serverBuilder.build();
130   }
131 
assertCallSuccess(MethodDescriptor<Empty, Empty> method)132   private void assertCallSuccess(MethodDescriptor<Empty, Empty> method) {
133     assertThat(
134             ClientCalls.blockingUnaryCall(
135                 channel, method, CallOptions.DEFAULT, Empty.getDefaultInstance()))
136         .isNotNull();
137   }
138 
assertCallFailure(MethodDescriptor<Empty, Empty> method, Status status)139   private void assertCallFailure(MethodDescriptor<Empty, Empty> method, Status status) {
140     try {
141       ClientCalls.blockingUnaryCall(channel, method, CallOptions.DEFAULT, null);
142       fail();
143     } catch (StatusRuntimeException sre) {
144       assertThat(sre.getStatus().getCode()).isEqualTo(status.getCode());
145     }
146   }
147 
148   @Test
testAllowedCall()149   public void testAllowedCall() throws Exception {
150     createChannel();
151     assertThat(methods).isNotEmpty();
152     for (MethodDescriptor<Empty, Empty> method : methods.values()) {
153       assertCallSuccess(method);
154     }
155   }
156 
157   @Test
testServerDisllowsCalls()158   public void testServerDisllowsCalls() throws Exception {
159     createChannel(
160         ServerSecurityPolicy.newBuilder()
161             .servicePolicy("foo", policy((uid) -> false))
162             .servicePolicy("bar", policy((uid) -> false))
163             .servicePolicy("baz", policy((uid) -> false))
164             .build(),
165         SecurityPolicies.internalOnly());
166     assertThat(methods).isNotEmpty();
167     for (MethodDescriptor<Empty, Empty> method : methods.values()) {
168       assertCallFailure(method, Status.PERMISSION_DENIED);
169     }
170   }
171 
172   @Test
testClientDoesntTrustServer()173   public void testClientDoesntTrustServer() throws Exception {
174     createChannel(SecurityPolicies.serverInternalOnly(), policy((uid) -> false));
175     assertThat(methods).isNotEmpty();
176     for (MethodDescriptor<Empty, Empty> method : methods.values()) {
177       assertCallFailure(method, Status.PERMISSION_DENIED);
178     }
179   }
180 
181   @Test
testPerServicePolicy()182   public void testPerServicePolicy() throws Exception {
183     createChannel(
184         ServerSecurityPolicy.newBuilder()
185             .servicePolicy("foo", policy((uid) -> true))
186             .servicePolicy("bar", policy((uid) -> false))
187             .build(),
188         SecurityPolicies.internalOnly());
189 
190     assertThat(methods).isNotEmpty();
191     for (MethodDescriptor<Empty, Empty> method : methods.values()) {
192       if (method.getServiceName().equals("bar")) {
193         assertCallFailure(method, Status.PERMISSION_DENIED);
194       } else {
195         assertCallSuccess(method);
196       }
197     }
198   }
199 
200   @Test
testSecurityInterceptorIsClosestToTransport()201   public void testSecurityInterceptorIsClosestToTransport() throws Exception {
202     createChannel(
203         ServerSecurityPolicy.newBuilder()
204             .servicePolicy("foo", policy((uid) -> true))
205             .servicePolicy("bar", policy((uid) -> false))
206             .servicePolicy("baz", policy((uid) -> false))
207             .build(),
208         SecurityPolicies.internalOnly());
209     assertThat(countingServerInterceptor.numInterceptedCalls).isEqualTo(0);
210     for (MethodDescriptor<Empty, Empty> method : methods.values()) {
211       try {
212         ClientCalls.blockingUnaryCall(channel, method, CallOptions.DEFAULT, null);
213       } catch (StatusRuntimeException sre) {
214         // Ignore.
215       }
216     }
217     // Only the foo calls should have made it to the user interceptor.
218     assertThat(countingServerInterceptor.numInterceptedCalls).isEqualTo(2);
219   }
220 
policy(Function<Integer, Boolean> func)221   private static SecurityPolicy policy(Function<Integer, Boolean> func) {
222     return new SecurityPolicy() {
223       @Override
224       public Status checkAuthorization(int uid) {
225         return func.apply(uid) ? Status.OK : Status.PERMISSION_DENIED;
226       }
227     };
228   }
229 
230   private final class CountingServerInterceptor implements ServerInterceptor {
231     int numInterceptedCalls;
232 
233     @Override
234     public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
235         ServerCall<ReqT, RespT> call,
236         Metadata headers,
237         ServerCallHandler<ReqT, RespT> next) {
238       numInterceptedCalls += 1;
239       return next.startCall(call, headers);
240     }
241   }
242 }
243