/* * Copyright 2014 The gRPC Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package io.grpc; import static com.google.common.truth.Truth.assertThat; import static java.util.concurrent.TimeUnit.NANOSECONDS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import io.grpc.ClientInterceptors.CheckedForwardingClientCall; import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; import io.grpc.testing.TestMethodDescriptors; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; /** Unit tests for {@link ClientInterceptors}. */ @RunWith(JUnit4.class) public class ClientInterceptorsTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @Mock private Channel channel; private BaseClientCall call = new BaseClientCall(); private final MethodDescriptor method = TestMethodDescriptors.voidMethod(); /** * Sets up mocks. */ @Before public void setUp() { when(channel.newCall( ArgumentMatchers.>any(), any(CallOptions.class))) .thenReturn(call); } @Test(expected = NullPointerException.class) public void npeForNullChannel() { ClientInterceptors.intercept(null, Arrays.asList()); } @Test(expected = NullPointerException.class) public void npeForNullInterceptorList() { ClientInterceptors.intercept(channel, (List) null); } @Test(expected = NullPointerException.class) public void npeForNullInterceptor() { ClientInterceptors.intercept(channel, (ClientInterceptor) null); } @Test public void noop() { assertSame(channel, ClientInterceptors.intercept(channel, Arrays.asList())); } @Test public void channelAndInterceptorCalled() { ClientInterceptor interceptor = mock(ClientInterceptor.class, delegatesTo(new NoopInterceptor())); Channel intercepted = ClientInterceptors.intercept(channel, interceptor); CallOptions callOptions = CallOptions.DEFAULT; // First call assertSame(call, intercepted.newCall(method, callOptions)); verify(channel).newCall(same(method), same(callOptions)); verify(interceptor) .interceptCall(same(method), same(callOptions), ArgumentMatchers.any()); verifyNoMoreInteractions(channel, interceptor); // Second call assertSame(call, intercepted.newCall(method, callOptions)); verify(channel, times(2)).newCall(same(method), same(callOptions)); verify(interceptor, times(2)) .interceptCall(same(method), same(callOptions), ArgumentMatchers.any()); verifyNoMoreInteractions(channel, interceptor); } @Test public void callNextTwice() { ClientInterceptor interceptor = new ClientInterceptor() { @Override public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { // Calling next twice is permitted, although should only rarely be useful. assertSame(call, next.newCall(method, callOptions)); return next.newCall(method, callOptions); } }; Channel intercepted = ClientInterceptors.intercept(channel, interceptor); assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT)); verify(channel, times(2)).newCall(same(method), same(CallOptions.DEFAULT)); verifyNoMoreInteractions(channel); } @Test public void ordered() { final List order = new ArrayList<>(); channel = new Channel() { @SuppressWarnings("unchecked") @Override public ClientCall newCall( MethodDescriptor method, CallOptions callOptions) { order.add("channel"); return (ClientCall) call; } @Override public String authority() { return null; } }; ClientInterceptor interceptor1 = new ClientInterceptor() { @Override public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { order.add("i1"); return next.newCall(method, callOptions); } }; ClientInterceptor interceptor2 = new ClientInterceptor() { @Override public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { order.add("i2"); return next.newCall(method, callOptions); } }; Channel intercepted = ClientInterceptors.intercept(channel, interceptor1, interceptor2); assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT)); assertEquals(Arrays.asList("i2", "i1", "channel"), order); } @Test public void orderedForward() { final List order = new ArrayList<>(); channel = new Channel() { @SuppressWarnings("unchecked") @Override public ClientCall newCall( MethodDescriptor method, CallOptions callOptions) { order.add("channel"); return (ClientCall) call; } @Override public String authority() { return null; } }; ClientInterceptor interceptor1 = new ClientInterceptor() { @Override public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { order.add("i1"); return next.newCall(method, callOptions); } }; ClientInterceptor interceptor2 = new ClientInterceptor() { @Override public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { order.add("i2"); return next.newCall(method, callOptions); } }; Channel intercepted = ClientInterceptors.interceptForward(channel, interceptor1, interceptor2); assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT)); assertEquals(Arrays.asList("i1", "i2", "channel"), order); } @Test public void callOptions() { final CallOptions initialCallOptions = CallOptions.DEFAULT.withDeadlineAfter(100, NANOSECONDS); final CallOptions newCallOptions = initialCallOptions.withDeadlineAfter(300, NANOSECONDS); assertNotSame(initialCallOptions, newCallOptions); ClientInterceptor interceptor = mock(ClientInterceptor.class, delegatesTo(new ClientInterceptor() { @Override public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { return next.newCall(method, newCallOptions); } })); Channel intercepted = ClientInterceptors.intercept(channel, interceptor); intercepted.newCall(method, initialCallOptions); verify(interceptor) .interceptCall(same(method), same(initialCallOptions), ArgumentMatchers.any()); verify(channel).newCall(same(method), same(newCallOptions)); } @Test public void addOutboundHeaders() { final Metadata.Key credKey = Metadata.Key.of("Cred", Metadata.ASCII_STRING_MARSHALLER); ClientInterceptor interceptor = new ClientInterceptor() { @Override public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { ClientCall call = next.newCall(method, callOptions); return new SimpleForwardingClientCall(call) { @Override public void start(ClientCall.Listener responseListener, Metadata headers) { headers.put(credKey, "abcd"); super.start(responseListener, headers); } }; } }; Channel intercepted = ClientInterceptors.intercept(channel, interceptor); @SuppressWarnings("unchecked") ClientCall.Listener listener = mock(ClientCall.Listener.class); ClientCall interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT); // start() on the intercepted call will eventually reach the call created by the real channel interceptedCall.start(listener, new Metadata()); // The headers passed to the real channel call will contain the information inserted by the // interceptor. assertSame(listener, call.listener); assertEquals("abcd", call.headers.get(credKey)); } @Test public void examineInboundHeaders() { final List examinedHeaders = new ArrayList<>(); ClientInterceptor interceptor = new ClientInterceptor() { @Override public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { ClientCall call = next.newCall(method, callOptions); return new SimpleForwardingClientCall(call) { @Override public void start(ClientCall.Listener responseListener, Metadata headers) { super.start(new SimpleForwardingClientCallListener(responseListener) { @Override public void onHeaders(Metadata headers) { examinedHeaders.add(headers); super.onHeaders(headers); } }, headers); } }; } }; Channel intercepted = ClientInterceptors.intercept(channel, interceptor); @SuppressWarnings("unchecked") ClientCall.Listener listener = mock(ClientCall.Listener.class); ClientCall interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT); interceptedCall.start(listener, new Metadata()); // Capture the underlying call listener that will receive headers from the transport. Metadata inboundHeaders = new Metadata(); // Simulate that a headers arrives on the underlying call listener. call.listener.onHeaders(inboundHeaders); assertThat(examinedHeaders).contains(inboundHeaders); } @Test public void normalCall() { ClientInterceptor interceptor = new ClientInterceptor() { @Override public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { ClientCall call = next.newCall(method, callOptions); return new SimpleForwardingClientCall(call) { }; } }; Channel intercepted = ClientInterceptors.intercept(channel, interceptor); ClientCall interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT); assertNotSame(call, interceptedCall); @SuppressWarnings("unchecked") ClientCall.Listener listener = mock(ClientCall.Listener.class); Metadata headers = new Metadata(); interceptedCall.start(listener, headers); assertSame(listener, call.listener); assertSame(headers, call.headers); interceptedCall.sendMessage(null /*request*/); assertThat(call.messages).containsExactly((String) null); interceptedCall.halfClose(); assertTrue(call.halfClosed); interceptedCall.request(1); assertThat(call.requests).containsExactly(1); } @Test public void exceptionInStart() { final Exception error = new Exception("emulated error"); ClientInterceptor interceptor = new ClientInterceptor() { @Override public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { ClientCall call = next.newCall(method, callOptions); return new CheckedForwardingClientCall(call) { @Override protected void checkedStart(ClientCall.Listener responseListener, Metadata headers) throws Exception { throw error; // delegate().start will not be called } }; } }; Channel intercepted = ClientInterceptors.intercept(channel, interceptor); @SuppressWarnings("unchecked") ClientCall.Listener listener = mock(ClientCall.Listener.class); ClientCall interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT); assertNotSame(call, interceptedCall); interceptedCall.start(listener, new Metadata()); interceptedCall.sendMessage(null /*request*/); interceptedCall.halfClose(); interceptedCall.request(1); call.done = true; ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); verify(listener).onClose(captor.capture(), any(Metadata.class)); assertSame(error, captor.getValue().getCause()); // Make sure nothing bad happens after the exception. ClientCall noop = ((CheckedForwardingClientCall)interceptedCall).delegate(); // Should not throw, even on bad input noop.cancel("Cancel for test", null); noop.start(null, null); noop.request(-1); noop.halfClose(); noop.sendMessage(null); assertFalse(noop.isReady()); } @Test public void authorityIsDelegated() { ClientInterceptor interceptor = new ClientInterceptor() { @Override public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { return next.newCall(method, callOptions); } }; when(channel.authority()).thenReturn("auth"); Channel intercepted = ClientInterceptors.intercept(channel, interceptor); assertEquals("auth", intercepted.authority()); } @Test public void customOptionAccessible() { CallOptions.Key customOption = CallOptions.Key.create("custom"); CallOptions callOptions = CallOptions.DEFAULT.withOption(customOption, "value"); ArgumentCaptor passedOptions = ArgumentCaptor.forClass(CallOptions.class); ClientInterceptor interceptor = mock(ClientInterceptor.class, delegatesTo(new NoopInterceptor())); Channel intercepted = ClientInterceptors.intercept(channel, interceptor); assertSame(call, intercepted.newCall(method, callOptions)); verify(channel).newCall(same(method), same(callOptions)); verify(interceptor).interceptCall(same(method), passedOptions.capture(), isA(Channel.class)); assertSame("value", passedOptions.getValue().getOption(customOption)); } private static class NoopInterceptor implements ClientInterceptor { @Override public ClientCall interceptCall(MethodDescriptor method, CallOptions callOptions, Channel next) { return next.newCall(method, callOptions); } } private static class BaseClientCall extends ClientCall { private boolean started; private boolean done; private ClientCall.Listener listener; private Metadata headers; private List requests = new ArrayList<>(); private List messages = new ArrayList<>(); private boolean halfClosed; @Override public void start(ClientCall.Listener listener, Metadata headers) { checkNotDone(); started = true; this.listener = listener; this.headers = headers; } @Override public void request(int numMessages) { checkNotDone(); checkStarted(); requests.add(numMessages); } @Override public void cancel(String message, Throwable cause) { checkNotDone(); } @Override public void halfClose() { checkNotDone(); checkStarted(); this.halfClosed = true; } @Override public void sendMessage(String message) { checkNotDone(); checkStarted(); messages.add(message); } private void checkNotDone() { if (done) { throw new IllegalStateException("no more methods should be called"); } } private void checkStarted() { if (!started) { throw new IllegalStateException("should have called start"); } } } }