/*
 * Copyright 2014 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package io.grpc;
18 
19 import static com.google.common.collect.Iterables.getOnlyElement;
20 import static org.junit.Assert.assertEquals;
21 import static org.junit.Assert.assertSame;
22 import static org.junit.Assert.assertTrue;
23 import static org.mockito.AdditionalAnswers.delegatesTo;
24 import static org.mockito.ArgumentMatchers.same;
25 import static org.mockito.Mockito.mock;
26 import static org.mockito.Mockito.times;
27 import static org.mockito.Mockito.verify;
28 import static org.mockito.Mockito.verifyNoInteractions;
29 import static org.mockito.Mockito.verifyNoMoreInteractions;
30 
31 import io.grpc.MethodDescriptor.Marshaller;
32 import io.grpc.MethodDescriptor.MethodType;
33 import io.grpc.ServerCall.Listener;
34 import io.grpc.internal.NoopServerCall;
35 import java.io.ByteArrayInputStream;
36 import java.io.InputStream;
37 import java.util.ArrayList;
38 import java.util.Arrays;
39 import java.util.List;
40 import org.junit.After;
41 import org.junit.Before;
42 import org.junit.Rule;
43 import org.junit.Test;
44 import org.junit.rules.ExpectedException;
45 import org.junit.runner.RunWith;
46 import org.junit.runners.JUnit4;
47 import org.mockito.ArgumentMatchers;
48 import org.mockito.Mock;
49 import org.mockito.Mockito;
50 import org.mockito.junit.MockitoJUnit;
51 import org.mockito.junit.MockitoRule;
52 
53 /** Unit tests for {@link ServerInterceptors}. */
54 @RunWith(JUnit4.class)
55 public class ServerInterceptorsTest {
56   @Rule
57   public final MockitoRule mocks = MockitoJUnit.rule();
58 
59   @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467
60   @Rule
61   public final ExpectedException thrown = ExpectedException.none();
62 
63   @Mock
64   private Marshaller<String> requestMarshaller;
65 
66   @Mock
67   private Marshaller<Integer> responseMarshaller;
68 
69   @Mock
70   private ServerCallHandler<String, Integer> handler;
71 
72   @Mock
73   private ServerCall.Listener<String> listener;
74 
75   private MethodDescriptor<String, Integer> flowMethod;
76 
77   private ServerCall<String, Integer> call = new NoopServerCall<>();
78 
79   private ServerServiceDefinition serviceDefinition;
80 
81   private final Metadata headers = new Metadata();
82 
83   /** Set up for test. */
84   @Before
setUp()85   public void setUp() {
86     flowMethod = MethodDescriptor.<String, Integer>newBuilder()
87         .setType(MethodType.UNKNOWN)
88         .setFullMethodName("basic/flow")
89         .setRequestMarshaller(requestMarshaller)
90         .setResponseMarshaller(responseMarshaller)
91         .build();
92 
93     Mockito.when(
94             handler.startCall(
95                 ArgumentMatchers.<ServerCall<String, Integer>>any(),
96                 ArgumentMatchers.<Metadata>any()))
97             .thenReturn(listener);
98 
99     serviceDefinition = ServerServiceDefinition.builder(new ServiceDescriptor("basic", flowMethod))
100         .addMethod(flowMethod, handler).build();
101   }
102 
103   /** Final checks for all tests. */
104   @After
makeSureExpectedMocksUnused()105   public void makeSureExpectedMocksUnused() {
106     verifyNoInteractions(requestMarshaller);
107     verifyNoInteractions(responseMarshaller);
108     verifyNoInteractions(listener);
109   }
110 
111   @Test
npeForNullServiceDefinition()112   public void npeForNullServiceDefinition() {
113     ServerServiceDefinition serviceDef = null;
114     List<ServerInterceptor> interceptors = Arrays.asList();
115     thrown.expect(NullPointerException.class);
116     ServerInterceptors.intercept(serviceDef, interceptors);
117   }
118 
119   @Test
npeForNullInterceptorList()120   public void npeForNullInterceptorList() {
121     thrown.expect(NullPointerException.class);
122     ServerInterceptors.intercept(serviceDefinition, (List<ServerInterceptor>) null);
123   }
124 
125   @Test
npeForNullInterceptor()126   public void npeForNullInterceptor() {
127     List<ServerInterceptor> interceptors = Arrays.asList((ServerInterceptor) null);
128     thrown.expect(NullPointerException.class);
129     ServerInterceptors.intercept(serviceDefinition, interceptors);
130   }
131 
132   @Test
noop()133   public void noop() {
134     assertSame(serviceDefinition,
135         ServerInterceptors.intercept(serviceDefinition, Arrays.<ServerInterceptor>asList()));
136   }
137 
138   @Test
multipleInvocationsOfHandler()139   public void multipleInvocationsOfHandler() {
140     ServerInterceptor interceptor =
141         mock(ServerInterceptor.class, delegatesTo(new NoopInterceptor()));
142     ServerServiceDefinition intercepted
143         = ServerInterceptors.intercept(serviceDefinition, Arrays.asList(interceptor));
144     assertSame(listener,
145         getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
146     verify(interceptor).interceptCall(same(call), same(headers), anyCallHandler());
147     verify(handler).startCall(call, headers);
148     verifyNoMoreInteractions(interceptor, handler);
149 
150     assertSame(listener,
151         getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
152     verify(interceptor, times(2))
153         .interceptCall(same(call), same(headers), anyCallHandler());
154     verify(handler, times(2)).startCall(call, headers);
155     verifyNoMoreInteractions(interceptor, handler);
156   }
157 
158   @Test
correctHandlerCalled()159   public void correctHandlerCalled() {
160     @SuppressWarnings("unchecked")
161     ServerCallHandler<String, Integer> handler2 = mock(ServerCallHandler.class);
162     MethodDescriptor<String, Integer> flowMethod2 =
163         flowMethod.toBuilder().setFullMethodName("basic/flow2").build();
164     serviceDefinition = ServerServiceDefinition.builder(
165         new ServiceDescriptor("basic", flowMethod, flowMethod2))
166         .addMethod(flowMethod, handler)
167         .addMethod(flowMethod2, handler2).build();
168     ServerServiceDefinition intercepted = ServerInterceptors.intercept(
169         serviceDefinition, Arrays.<ServerInterceptor>asList(new NoopInterceptor()));
170     getMethod(intercepted, "basic/flow").getServerCallHandler().startCall(call, headers);
171     verify(handler).startCall(call, headers);
172     verifyNoMoreInteractions(handler);
173     verifyNoMoreInteractions(handler2);
174 
175     getMethod(intercepted, "basic/flow2").getServerCallHandler().startCall(call, headers);
176     verify(handler2).startCall(call, headers);
177     verifyNoMoreInteractions(handler);
178     verifyNoMoreInteractions(handler2);
179   }
180 
181   @Test
callNextTwice()182   public void callNextTwice() {
183     ServerInterceptor interceptor = new ServerInterceptor() {
184       @Override
185       public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
186           ServerCall<ReqT, RespT> call,
187           Metadata headers,
188           ServerCallHandler<ReqT, RespT> next) {
189         // Calling next twice is permitted, although should only rarely be useful.
190         assertSame(listener, next.startCall(call, headers));
191         return next.startCall(call, headers);
192       }
193     };
194     ServerServiceDefinition intercepted = ServerInterceptors.intercept(serviceDefinition,
195         interceptor);
196     assertSame(listener,
197         getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
198     verify(handler, times(2)).startCall(same(call), same(headers));
199     verifyNoMoreInteractions(handler);
200   }
201 
202   @Test
ordered()203   public void ordered() {
204     final List<String> order = new ArrayList<>();
205     handler = new ServerCallHandler<String, Integer>() {
206           @Override
207           public ServerCall.Listener<String> startCall(
208               ServerCall<String, Integer> call,
209               Metadata headers) {
210             order.add("handler");
211             return listener;
212           }
213         };
214     ServerInterceptor interceptor1 = new ServerInterceptor() {
215           @Override
216           public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
217               ServerCall<ReqT, RespT> call,
218               Metadata headers,
219               ServerCallHandler<ReqT, RespT> next) {
220             order.add("i1");
221             return next.startCall(call, headers);
222           }
223         };
224     ServerInterceptor interceptor2 = new ServerInterceptor() {
225           @Override
226           public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
227               ServerCall<ReqT, RespT> call,
228               Metadata headers,
229               ServerCallHandler<ReqT, RespT> next) {
230             order.add("i2");
231             return next.startCall(call, headers);
232           }
233         };
234     ServerServiceDefinition serviceDefinition = ServerServiceDefinition.builder(
235         new ServiceDescriptor("basic", flowMethod))
236         .addMethod(flowMethod, handler).build();
237     ServerServiceDefinition intercepted = ServerInterceptors.intercept(
238         serviceDefinition, Arrays.asList(interceptor1, interceptor2));
239     assertSame(listener,
240         getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
241     assertEquals(Arrays.asList("i2", "i1", "handler"), order);
242   }
243 
244   @Test
orderedForward()245   public void orderedForward() {
246     final List<String> order = new ArrayList<>();
247     handler = new ServerCallHandler<String, Integer>() {
248       @Override
249       public ServerCall.Listener<String> startCall(
250           ServerCall<String, Integer> call,
251           Metadata headers) {
252         order.add("handler");
253         return listener;
254       }
255     };
256     ServerInterceptor interceptor1 = new ServerInterceptor() {
257       @Override
258       public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
259           ServerCall<ReqT, RespT> call,
260           Metadata headers,
261           ServerCallHandler<ReqT, RespT> next) {
262         order.add("i1");
263         return next.startCall(call, headers);
264       }
265     };
266     ServerInterceptor interceptor2 = new ServerInterceptor() {
267       @Override
268       public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
269           ServerCall<ReqT, RespT> call,
270           Metadata headers,
271           ServerCallHandler<ReqT, RespT> next) {
272         order.add("i2");
273         return next.startCall(call, headers);
274       }
275     };
276     ServerServiceDefinition serviceDefinition = ServerServiceDefinition.builder(
277         new ServiceDescriptor("basic", flowMethod))
278         .addMethod(flowMethod, handler).build();
279     ServerServiceDefinition intercepted = ServerInterceptors.interceptForward(
280         serviceDefinition, interceptor1, interceptor2);
281     assertSame(listener,
282         getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
283     assertEquals(Arrays.asList("i1", "i2", "handler"), order);
284   }
285 
286   @Test
argumentsPassed()287   public void argumentsPassed() {
288     final ServerCall<String, Integer> call2 = new NoopServerCall<>();
289     @SuppressWarnings("unchecked")
290     final ServerCall.Listener<String> listener2 = mock(ServerCall.Listener.class);
291 
292     ServerInterceptor interceptor = new ServerInterceptor() {
293         @SuppressWarnings("unchecked") // Lot's of casting for no benefit.  Not intended use.
294         @Override
295         public <R1, R2> ServerCall.Listener<R1> interceptCall(
296             ServerCall<R1, R2> call,
297             Metadata headers,
298             ServerCallHandler<R1, R2> next) {
299           assertSame(call, ServerInterceptorsTest.this.call);
300           assertSame(listener,
301               next.startCall((ServerCall<R1, R2>)call2, headers));
302           return (ServerCall.Listener<R1>) listener2;
303         }
304       };
305     ServerServiceDefinition intercepted = ServerInterceptors.intercept(
306         serviceDefinition, Arrays.asList(interceptor));
307     assertSame(listener2,
308         getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
309     verify(handler).startCall(call2, headers);
310   }
311 
312   @Test
313   @SuppressWarnings("unchecked")
typedMarshalledMessages()314   public void typedMarshalledMessages() {
315     final List<String> order = new ArrayList<>();
316     Marshaller<Holder> marshaller = new Marshaller<Holder>() {
317       @Override
318       public InputStream stream(Holder value) {
319         return value.get();
320       }
321 
322       @Override
323       public Holder parse(InputStream stream) {
324         return new Holder(stream);
325       }
326     };
327 
328     ServerCallHandler<Holder, Holder> handler2 = new ServerCallHandler<Holder, Holder>() {
329       @Override
330       public Listener<Holder> startCall(final ServerCall<Holder, Holder> call,
331                                         final Metadata headers) {
332         return new Listener<Holder>() {
333           @Override
334           public void onMessage(Holder message) {
335             order.add("handler");
336             call.sendMessage(message);
337           }
338         };
339       }
340     };
341 
342     MethodDescriptor<Holder, Holder> wrappedMethod = MethodDescriptor.<Holder, Holder>newBuilder()
343         .setType(MethodType.UNKNOWN)
344         .setFullMethodName("basic/wrapped")
345         .setRequestMarshaller(marshaller)
346         .setResponseMarshaller(marshaller)
347         .build();
348     ServerServiceDefinition serviceDef = ServerServiceDefinition.builder(
349         new ServiceDescriptor("basic", wrappedMethod))
350         .addMethod(wrappedMethod, handler2).build();
351 
352     ServerInterceptor interceptor1 = new ServerInterceptor() {
353       @Override
354       public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
355                                                         Metadata headers,
356                                                         ServerCallHandler<ReqT, RespT> next) {
357         ServerCall<ReqT, RespT> interceptedCall = new ForwardingServerCall
358             .SimpleForwardingServerCall<ReqT, RespT>(call) {
359           @Override
360           public void sendMessage(RespT message) {
361             order.add("i1sendMessage");
362             assertTrue(message instanceof Holder);
363             super.sendMessage(message);
364           }
365         };
366 
367         ServerCall.Listener<ReqT> originalListener = next
368             .startCall(interceptedCall, headers);
369         return new ForwardingServerCallListener
370             .SimpleForwardingServerCallListener<ReqT>(originalListener) {
371           @Override
372           public void onMessage(ReqT message) {
373             order.add("i1onMessage");
374             assertTrue(message instanceof Holder);
375             super.onMessage(message);
376           }
377         };
378       }
379     };
380 
381     ServerInterceptor interceptor2 = new ServerInterceptor() {
382       @Override
383       public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
384                                                         Metadata headers,
385                                                         ServerCallHandler<ReqT, RespT> next) {
386         ServerCall<ReqT, RespT> interceptedCall = new ForwardingServerCall
387             .SimpleForwardingServerCall<ReqT, RespT>(call) {
388           @Override
389           public void sendMessage(RespT message) {
390             order.add("i2sendMessage");
391             assertTrue(message instanceof InputStream);
392             super.sendMessage(message);
393           }
394         };
395 
396         ServerCall.Listener<ReqT> originalListener = next
397             .startCall(interceptedCall, headers);
398         return new ForwardingServerCallListener
399             .SimpleForwardingServerCallListener<ReqT>(originalListener) {
400           @Override
401           public void onMessage(ReqT message) {
402             order.add("i2onMessage");
403             assertTrue(message instanceof InputStream);
404             super.onMessage(message);
405           }
406         };
407       }
408     };
409 
410     ServerServiceDefinition intercepted = ServerInterceptors.intercept(serviceDef, interceptor1);
411     ServerServiceDefinition inputStreamMessageService = ServerInterceptors
412         .useInputStreamMessages(intercepted);
413     ServerServiceDefinition intercepted2 = ServerInterceptors
414         .intercept(inputStreamMessageService, interceptor2);
415     ServerMethodDefinition<InputStream, InputStream> serverMethod =
416         (ServerMethodDefinition<InputStream, InputStream>) intercepted2.getMethod("basic/wrapped");
417     ServerCall<InputStream, InputStream> call2 = new NoopServerCall<>();
418     byte[] bytes = {};
419     serverMethod
420         .getServerCallHandler()
421         .startCall(call2, headers)
422         .onMessage(new ByteArrayInputStream(bytes));
423     assertEquals(
424         Arrays.asList("i2onMessage", "i1onMessage", "handler", "i1sendMessage", "i2sendMessage"),
425         order);
426   }
427 
428   /**
429    * Tests the ServerInterceptors#useMarshalledMessages()} with two marshallers. Makes sure that
430    * on incoming request the request marshaller's stream method is called and on response the
431    * response marshaller's parse method is called
432    */
433   @Test
434   @SuppressWarnings("unchecked")
distinctMarshallerForRequestAndResponse()435   public void distinctMarshallerForRequestAndResponse() {
436     final List<String> requestFlowOrder = new ArrayList<>();
437 
438     final Marshaller<String> requestMarshaller = new Marshaller<String>() {
439       @Override
440       public InputStream stream(String value) {
441         requestFlowOrder.add("RequestStream");
442         return null;
443       }
444 
445       @Override
446       public String parse(InputStream stream) {
447         requestFlowOrder.add("RequestParse");
448         return null;
449       }
450     };
451     final Marshaller<String> responseMarshaller = new Marshaller<String>() {
452       @Override
453       public InputStream stream(String value) {
454         requestFlowOrder.add("ResponseStream");
455         return null;
456       }
457 
458       @Override
459       public String parse(InputStream stream) {
460         requestFlowOrder.add("ResponseParse");
461         return null;
462       }
463     };
464     final Marshaller<Holder> dummyMarshaller = new Marshaller<Holder>() {
465       @Override
466       public InputStream stream(Holder value) {
467         return value.get();
468       }
469 
470       @Override
471       public Holder parse(InputStream stream) {
472         return new Holder(stream);
473       }
474     };
475     ServerCallHandler<Holder, Holder> handler = (call, headers) -> new Listener<Holder>() {
476       @Override
477       public void onMessage(Holder message) {
478         requestFlowOrder.add("handler");
479         call.sendMessage(message);
480       }
481     };
482 
483     MethodDescriptor<Holder, Holder> wrappedMethod = MethodDescriptor.<Holder, Holder>newBuilder()
484         .setType(MethodType.UNKNOWN)
485         .setFullMethodName("basic/wrapped")
486         .setRequestMarshaller(dummyMarshaller)
487         .setResponseMarshaller(dummyMarshaller)
488         .build();
489     ServerServiceDefinition serviceDef = ServerServiceDefinition.builder(
490             new ServiceDescriptor("basic", wrappedMethod))
491         .addMethod(wrappedMethod, handler).build();
492     ServerServiceDefinition intercepted = ServerInterceptors.useMarshalledMessages(serviceDef,
493         requestMarshaller, responseMarshaller);
494     ServerMethodDefinition<String, String> serverMethod =
495         (ServerMethodDefinition<String, String>) intercepted.getMethod("basic/wrapped");
496     ServerCall<String, String> serverCall = new NoopServerCall<>();
497     serverMethod.getServerCallHandler().startCall(serverCall, headers).onMessage("TestMessage");
498 
499     assertEquals(Arrays.asList("RequestStream",  "handler", "ResponseParse"), requestFlowOrder);
500   }
501 
502   @SuppressWarnings("unchecked")
getSoleMethod( ServerServiceDefinition serviceDef)503   private static ServerMethodDefinition<String, Integer> getSoleMethod(
504       ServerServiceDefinition serviceDef) {
505     if (serviceDef.getMethods().size() != 1) {
506       throw new AssertionError("Not exactly one method present");
507     }
508     return (ServerMethodDefinition<String, Integer>) getOnlyElement(serviceDef.getMethods());
509   }
510 
511   @SuppressWarnings("unchecked")
getMethod( ServerServiceDefinition serviceDef, String name)512   private static ServerMethodDefinition<String, Integer> getMethod(
513       ServerServiceDefinition serviceDef, String name) {
514     return (ServerMethodDefinition<String, Integer>) serviceDef.getMethod(name);
515   }
516 
anyCallHandler()517   private ServerCallHandler<String, Integer> anyCallHandler() {
518     return ArgumentMatchers.any();
519   }
520 
521   private static class NoopInterceptor implements ServerInterceptor {
522     @Override
interceptCall( ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next)523     public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
524         ServerCall<ReqT, RespT> call,
525         Metadata headers,
526         ServerCallHandler<ReqT, RespT> next) {
527       return next.startCall(call, headers);
528     }
529   }
530 
531   private static class Holder {
532     private final InputStream inputStream;
533 
Holder(InputStream inputStream)534     Holder(InputStream inputStream) {
535       this.inputStream = inputStream;
536     }
537 
get()538     public InputStream get() {
539       return inputStream;
540     }
541   }
542 }
543