/* * Copyright 2014 The gRPC Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.grpc; import static com.google.common.collect.Iterables.getOnlyElement; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.MethodType; import io.grpc.ServerCall.Listener; import io.grpc.internal.NoopServerCall; import java.io.ByteArrayInputStream; import java.io.InputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentMatchers; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; /** Unit tests for {@link ServerInterceptors}. */ @RunWith(JUnit4.class) public class ServerInterceptorsTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 @Rule public final ExpectedException thrown = ExpectedException.none(); @Mock private Marshaller requestMarshaller; @Mock private Marshaller responseMarshaller; @Mock private ServerCallHandler handler; @Mock private ServerCall.Listener listener; private MethodDescriptor flowMethod; private ServerCall call = new NoopServerCall<>(); private ServerServiceDefinition serviceDefinition; private final Metadata headers = new Metadata(); /** Set up for test. */ @Before public void setUp() { flowMethod = MethodDescriptor.newBuilder() .setType(MethodType.UNKNOWN) .setFullMethodName("basic/flow") .setRequestMarshaller(requestMarshaller) .setResponseMarshaller(responseMarshaller) .build(); Mockito.when( handler.startCall( ArgumentMatchers.>any(), ArgumentMatchers.any())) .thenReturn(listener); serviceDefinition = ServerServiceDefinition.builder(new ServiceDescriptor("basic", flowMethod)) .addMethod(flowMethod, handler).build(); } /** Final checks for all tests. */ @After public void makeSureExpectedMocksUnused() { verifyNoInteractions(requestMarshaller); verifyNoInteractions(responseMarshaller); verifyNoInteractions(listener); } @Test public void npeForNullServiceDefinition() { ServerServiceDefinition serviceDef = null; List interceptors = Arrays.asList(); thrown.expect(NullPointerException.class); ServerInterceptors.intercept(serviceDef, interceptors); } @Test public void npeForNullInterceptorList() { thrown.expect(NullPointerException.class); ServerInterceptors.intercept(serviceDefinition, (List) null); } @Test public void npeForNullInterceptor() { List interceptors = Arrays.asList((ServerInterceptor) null); thrown.expect(NullPointerException.class); ServerInterceptors.intercept(serviceDefinition, interceptors); } @Test public void noop() { assertSame(serviceDefinition, ServerInterceptors.intercept(serviceDefinition, Arrays.asList())); } @Test public void multipleInvocationsOfHandler() { ServerInterceptor interceptor = mock(ServerInterceptor.class, delegatesTo(new NoopInterceptor())); ServerServiceDefinition intercepted = ServerInterceptors.intercept(serviceDefinition, Arrays.asList(interceptor)); assertSame(listener, getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers)); verify(interceptor).interceptCall(same(call), same(headers), anyCallHandler()); verify(handler).startCall(call, headers); verifyNoMoreInteractions(interceptor, handler); assertSame(listener, getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers)); verify(interceptor, times(2)) .interceptCall(same(call), same(headers), anyCallHandler()); verify(handler, times(2)).startCall(call, headers); verifyNoMoreInteractions(interceptor, handler); } @Test public void correctHandlerCalled() { @SuppressWarnings("unchecked") ServerCallHandler handler2 = mock(ServerCallHandler.class); MethodDescriptor flowMethod2 = flowMethod.toBuilder().setFullMethodName("basic/flow2").build(); serviceDefinition = ServerServiceDefinition.builder( new ServiceDescriptor("basic", flowMethod, flowMethod2)) .addMethod(flowMethod, handler) .addMethod(flowMethod2, handler2).build(); ServerServiceDefinition intercepted = ServerInterceptors.intercept( serviceDefinition, Arrays.asList(new NoopInterceptor())); getMethod(intercepted, "basic/flow").getServerCallHandler().startCall(call, headers); verify(handler).startCall(call, headers); verifyNoMoreInteractions(handler); verifyNoMoreInteractions(handler2); getMethod(intercepted, "basic/flow2").getServerCallHandler().startCall(call, headers); verify(handler2).startCall(call, headers); verifyNoMoreInteractions(handler); verifyNoMoreInteractions(handler2); } @Test public void callNextTwice() { ServerInterceptor interceptor = new ServerInterceptor() { @Override public ServerCall.Listener interceptCall( ServerCall call, Metadata headers, ServerCallHandler next) { // Calling next twice is permitted, although should only rarely be useful. assertSame(listener, next.startCall(call, headers)); return next.startCall(call, headers); } }; ServerServiceDefinition intercepted = ServerInterceptors.intercept(serviceDefinition, interceptor); assertSame(listener, getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers)); verify(handler, times(2)).startCall(same(call), same(headers)); verifyNoMoreInteractions(handler); } @Test public void ordered() { final List order = new ArrayList<>(); handler = new ServerCallHandler() { @Override public ServerCall.Listener startCall( ServerCall call, Metadata headers) { order.add("handler"); return listener; } }; ServerInterceptor interceptor1 = new ServerInterceptor() { @Override public ServerCall.Listener interceptCall( ServerCall call, Metadata headers, ServerCallHandler next) { order.add("i1"); return next.startCall(call, headers); } }; ServerInterceptor interceptor2 = new ServerInterceptor() { @Override public ServerCall.Listener interceptCall( ServerCall call, Metadata headers, ServerCallHandler next) { order.add("i2"); return next.startCall(call, headers); } }; ServerServiceDefinition serviceDefinition = ServerServiceDefinition.builder( new ServiceDescriptor("basic", flowMethod)) .addMethod(flowMethod, handler).build(); ServerServiceDefinition intercepted = ServerInterceptors.intercept( serviceDefinition, Arrays.asList(interceptor1, interceptor2)); assertSame(listener, getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers)); assertEquals(Arrays.asList("i2", "i1", "handler"), order); } @Test public void orderedForward() { final List order = new ArrayList<>(); handler = new ServerCallHandler() { @Override public ServerCall.Listener startCall( ServerCall call, Metadata headers) { order.add("handler"); return listener; } }; ServerInterceptor interceptor1 = new ServerInterceptor() { @Override public ServerCall.Listener interceptCall( ServerCall call, Metadata headers, ServerCallHandler next) { order.add("i1"); return next.startCall(call, headers); } }; ServerInterceptor interceptor2 = new ServerInterceptor() { @Override public ServerCall.Listener interceptCall( ServerCall call, Metadata headers, ServerCallHandler next) { order.add("i2"); return next.startCall(call, headers); } }; ServerServiceDefinition serviceDefinition = ServerServiceDefinition.builder( new ServiceDescriptor("basic", flowMethod)) .addMethod(flowMethod, handler).build(); ServerServiceDefinition intercepted = ServerInterceptors.interceptForward( serviceDefinition, interceptor1, interceptor2); assertSame(listener, getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers)); assertEquals(Arrays.asList("i1", "i2", "handler"), order); } @Test public void argumentsPassed() { final ServerCall call2 = new NoopServerCall<>(); @SuppressWarnings("unchecked") final ServerCall.Listener listener2 = mock(ServerCall.Listener.class); ServerInterceptor interceptor = new ServerInterceptor() { @SuppressWarnings("unchecked") // Lot's of casting for no benefit. Not intended use. @Override public ServerCall.Listener interceptCall( ServerCall call, Metadata headers, ServerCallHandler next) { assertSame(call, ServerInterceptorsTest.this.call); assertSame(listener, next.startCall((ServerCall)call2, headers)); return (ServerCall.Listener) listener2; } }; ServerServiceDefinition intercepted = ServerInterceptors.intercept( serviceDefinition, Arrays.asList(interceptor)); assertSame(listener2, getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers)); verify(handler).startCall(call2, headers); } @Test @SuppressWarnings("unchecked") public void typedMarshalledMessages() { final List order = new ArrayList<>(); Marshaller marshaller = new Marshaller() { @Override public InputStream stream(Holder value) { return value.get(); } @Override public Holder parse(InputStream stream) { return new Holder(stream); } }; ServerCallHandler handler2 = new ServerCallHandler() { @Override public Listener startCall(final ServerCall call, final Metadata headers) { return new Listener() { @Override public void onMessage(Holder message) { order.add("handler"); call.sendMessage(message); } }; } }; MethodDescriptor wrappedMethod = MethodDescriptor.newBuilder() .setType(MethodType.UNKNOWN) .setFullMethodName("basic/wrapped") .setRequestMarshaller(marshaller) .setResponseMarshaller(marshaller) .build(); ServerServiceDefinition serviceDef = ServerServiceDefinition.builder( new ServiceDescriptor("basic", wrappedMethod)) .addMethod(wrappedMethod, handler2).build(); ServerInterceptor interceptor1 = new ServerInterceptor() { @Override public Listener interceptCall(ServerCall call, Metadata headers, ServerCallHandler next) { ServerCall interceptedCall = new ForwardingServerCall .SimpleForwardingServerCall(call) { @Override public void sendMessage(RespT message) { order.add("i1sendMessage"); assertTrue(message instanceof Holder); super.sendMessage(message); } }; ServerCall.Listener originalListener = next .startCall(interceptedCall, headers); return new ForwardingServerCallListener .SimpleForwardingServerCallListener(originalListener) { @Override public void onMessage(ReqT message) { order.add("i1onMessage"); assertTrue(message instanceof Holder); super.onMessage(message); } }; } }; ServerInterceptor interceptor2 = new ServerInterceptor() { @Override public Listener interceptCall(ServerCall call, Metadata headers, ServerCallHandler next) { ServerCall interceptedCall = new ForwardingServerCall .SimpleForwardingServerCall(call) { @Override public void sendMessage(RespT message) { order.add("i2sendMessage"); assertTrue(message instanceof InputStream); super.sendMessage(message); } }; ServerCall.Listener originalListener = next .startCall(interceptedCall, headers); return new ForwardingServerCallListener .SimpleForwardingServerCallListener(originalListener) { @Override public void onMessage(ReqT message) { order.add("i2onMessage"); assertTrue(message instanceof InputStream); super.onMessage(message); } }; } }; ServerServiceDefinition intercepted = ServerInterceptors.intercept(serviceDef, interceptor1); ServerServiceDefinition inputStreamMessageService = ServerInterceptors .useInputStreamMessages(intercepted); ServerServiceDefinition intercepted2 = ServerInterceptors .intercept(inputStreamMessageService, interceptor2); ServerMethodDefinition serverMethod = (ServerMethodDefinition) intercepted2.getMethod("basic/wrapped"); ServerCall call2 = new NoopServerCall<>(); byte[] bytes = {}; serverMethod .getServerCallHandler() .startCall(call2, headers) .onMessage(new ByteArrayInputStream(bytes)); assertEquals( Arrays.asList("i2onMessage", "i1onMessage", "handler", "i1sendMessage", "i2sendMessage"), order); } /** * Tests the ServerInterceptors#useMarshalledMessages()} with two marshallers. Makes sure that * on incoming request the request marshaller's stream method is called and on response the * response marshaller's parse method is called */ @Test @SuppressWarnings("unchecked") public void distinctMarshallerForRequestAndResponse() { final List requestFlowOrder = new ArrayList<>(); final Marshaller requestMarshaller = new Marshaller() { @Override public InputStream stream(String value) { requestFlowOrder.add("RequestStream"); return null; } @Override public String parse(InputStream stream) { requestFlowOrder.add("RequestParse"); return null; } }; final Marshaller responseMarshaller = new Marshaller() { @Override public InputStream stream(String value) { requestFlowOrder.add("ResponseStream"); return null; } @Override public String parse(InputStream stream) { requestFlowOrder.add("ResponseParse"); return null; } }; final Marshaller dummyMarshaller = new Marshaller() { @Override public InputStream stream(Holder value) { return value.get(); } @Override public Holder parse(InputStream stream) { return new Holder(stream); } }; ServerCallHandler handler = (call, headers) -> new Listener() { @Override public void onMessage(Holder message) { requestFlowOrder.add("handler"); call.sendMessage(message); } }; MethodDescriptor wrappedMethod = MethodDescriptor.newBuilder() .setType(MethodType.UNKNOWN) .setFullMethodName("basic/wrapped") .setRequestMarshaller(dummyMarshaller) .setResponseMarshaller(dummyMarshaller) .build(); ServerServiceDefinition serviceDef = ServerServiceDefinition.builder( new ServiceDescriptor("basic", wrappedMethod)) .addMethod(wrappedMethod, handler).build(); ServerServiceDefinition intercepted = ServerInterceptors.useMarshalledMessages(serviceDef, requestMarshaller, responseMarshaller); ServerMethodDefinition serverMethod = (ServerMethodDefinition) intercepted.getMethod("basic/wrapped"); ServerCall serverCall = new NoopServerCall<>(); serverMethod.getServerCallHandler().startCall(serverCall, headers).onMessage("TestMessage"); assertEquals(Arrays.asList("RequestStream", "handler", "ResponseParse"), requestFlowOrder); } @SuppressWarnings("unchecked") private static ServerMethodDefinition getSoleMethod( ServerServiceDefinition serviceDef) { if (serviceDef.getMethods().size() != 1) { throw new AssertionError("Not exactly one method present"); } return (ServerMethodDefinition) getOnlyElement(serviceDef.getMethods()); } @SuppressWarnings("unchecked") private static ServerMethodDefinition getMethod( ServerServiceDefinition serviceDef, String name) { return (ServerMethodDefinition) serviceDef.getMethod(name); } private ServerCallHandler anyCallHandler() { return ArgumentMatchers.any(); } private static class NoopInterceptor implements ServerInterceptor { @Override public ServerCall.Listener interceptCall( ServerCall call, Metadata headers, ServerCallHandler next) { return next.startCall(call, headers); } } private static class Holder { private final InputStream inputStream; Holder(InputStream inputStream) { this.inputStream = inputStream; } public InputStream get() { return inputStream; } } }