• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright 2014 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package io.grpc;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static java.util.concurrent.TimeUnit.NANOSECONDS;
21 import static org.junit.Assert.assertEquals;
22 import static org.junit.Assert.assertFalse;
23 import static org.junit.Assert.assertNotSame;
24 import static org.junit.Assert.assertSame;
25 import static org.junit.Assert.assertTrue;
26 import static org.mockito.AdditionalAnswers.delegatesTo;
27 import static org.mockito.ArgumentMatchers.any;
28 import static org.mockito.ArgumentMatchers.isA;
29 import static org.mockito.ArgumentMatchers.same;
30 import static org.mockito.Mockito.mock;
31 import static org.mockito.Mockito.times;
32 import static org.mockito.Mockito.verify;
33 import static org.mockito.Mockito.verifyNoMoreInteractions;
34 import static org.mockito.Mockito.when;
35 
36 import io.grpc.ClientInterceptors.CheckedForwardingClientCall;
37 import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
38 import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
39 import io.grpc.testing.TestMethodDescriptors;
40 import java.util.ArrayList;
41 import java.util.Arrays;
42 import java.util.List;
43 import org.junit.Before;
44 import org.junit.Rule;
45 import org.junit.Test;
46 import org.junit.runner.RunWith;
47 import org.junit.runners.JUnit4;
48 import org.mockito.ArgumentCaptor;
49 import org.mockito.ArgumentMatchers;
50 import org.mockito.Mock;
51 import org.mockito.junit.MockitoJUnit;
52 import org.mockito.junit.MockitoRule;
53 
54 /** Unit tests for {@link ClientInterceptors}. */
55 @RunWith(JUnit4.class)
56 public class ClientInterceptorsTest {
57 
58   @Rule
59   public final MockitoRule mocks = MockitoJUnit.rule();
60 
61   @Mock
62   private Channel channel;
63 
64   private BaseClientCall call = new BaseClientCall();
65 
66   private final MethodDescriptor<Void, Void> method = TestMethodDescriptors.voidMethod();
67 
68   /**
69    * Sets up mocks.
70    */
setUp()71   @Before public void setUp() {
72     when(channel.newCall(
73             ArgumentMatchers.<MethodDescriptor<String, Integer>>any(), any(CallOptions.class)))
74         .thenReturn(call);
75   }
76 
77   @Test(expected = NullPointerException.class)
npeForNullChannel()78   public void npeForNullChannel() {
79     ClientInterceptors.intercept(null, Arrays.<ClientInterceptor>asList());
80   }
81 
82   @Test(expected = NullPointerException.class)
npeForNullInterceptorList()83   public void npeForNullInterceptorList() {
84     ClientInterceptors.intercept(channel, (List<ClientInterceptor>) null);
85   }
86 
87   @Test(expected = NullPointerException.class)
npeForNullInterceptor()88   public void npeForNullInterceptor() {
89     ClientInterceptors.intercept(channel, (ClientInterceptor) null);
90   }
91 
92   @Test
noop()93   public void noop() {
94     assertSame(channel, ClientInterceptors.intercept(channel, Arrays.<ClientInterceptor>asList()));
95   }
96 
97   @Test
channelAndInterceptorCalled()98   public void channelAndInterceptorCalled() {
99     ClientInterceptor interceptor =
100         mock(ClientInterceptor.class, delegatesTo(new NoopInterceptor()));
101     Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
102     CallOptions callOptions = CallOptions.DEFAULT;
103     // First call
104     assertSame(call, intercepted.newCall(method, callOptions));
105     verify(channel).newCall(same(method), same(callOptions));
106     verify(interceptor)
107         .interceptCall(same(method), same(callOptions), ArgumentMatchers.<Channel>any());
108     verifyNoMoreInteractions(channel, interceptor);
109     // Second call
110     assertSame(call, intercepted.newCall(method, callOptions));
111     verify(channel, times(2)).newCall(same(method), same(callOptions));
112     verify(interceptor, times(2))
113         .interceptCall(same(method), same(callOptions), ArgumentMatchers.<Channel>any());
114     verifyNoMoreInteractions(channel, interceptor);
115   }
116 
117   @Test
callNextTwice()118   public void callNextTwice() {
119     ClientInterceptor interceptor = new ClientInterceptor() {
120       @Override
121       public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
122           MethodDescriptor<ReqT, RespT> method,
123           CallOptions callOptions,
124           Channel next) {
125         // Calling next twice is permitted, although should only rarely be useful.
126         assertSame(call, next.newCall(method, callOptions));
127         return next.newCall(method, callOptions);
128       }
129     };
130     Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
131     assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT));
132     verify(channel, times(2)).newCall(same(method), same(CallOptions.DEFAULT));
133     verifyNoMoreInteractions(channel);
134   }
135 
136   @Test
ordered()137   public void ordered() {
138     final List<String> order = new ArrayList<>();
139     channel = new Channel() {
140       @SuppressWarnings("unchecked")
141       @Override
142       public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
143           MethodDescriptor<ReqT, RespT> method, CallOptions callOptions) {
144         order.add("channel");
145         return (ClientCall<ReqT, RespT>) call;
146       }
147 
148       @Override
149       public String authority() {
150         return null;
151       }
152     };
153     ClientInterceptor interceptor1 = new ClientInterceptor() {
154       @Override
155       public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
156           MethodDescriptor<ReqT, RespT> method,
157           CallOptions callOptions,
158           Channel next) {
159         order.add("i1");
160         return next.newCall(method, callOptions);
161       }
162     };
163     ClientInterceptor interceptor2 = new ClientInterceptor() {
164       @Override
165       public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
166           MethodDescriptor<ReqT, RespT> method,
167           CallOptions callOptions,
168           Channel next) {
169         order.add("i2");
170         return next.newCall(method, callOptions);
171       }
172     };
173     Channel intercepted = ClientInterceptors.intercept(channel, interceptor1, interceptor2);
174     assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT));
175     assertEquals(Arrays.asList("i2", "i1", "channel"), order);
176   }
177 
178   @Test
orderedForward()179   public void orderedForward() {
180     final List<String> order = new ArrayList<>();
181     channel = new Channel() {
182       @SuppressWarnings("unchecked")
183       @Override
184       public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
185           MethodDescriptor<ReqT, RespT> method, CallOptions callOptions) {
186         order.add("channel");
187         return (ClientCall<ReqT, RespT>) call;
188       }
189 
190       @Override
191       public String authority() {
192         return null;
193       }
194     };
195     ClientInterceptor interceptor1 = new ClientInterceptor() {
196       @Override
197       public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
198           MethodDescriptor<ReqT, RespT> method,
199           CallOptions callOptions,
200           Channel next) {
201         order.add("i1");
202         return next.newCall(method, callOptions);
203       }
204     };
205     ClientInterceptor interceptor2 = new ClientInterceptor() {
206       @Override
207       public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
208           MethodDescriptor<ReqT, RespT> method,
209           CallOptions callOptions,
210           Channel next) {
211         order.add("i2");
212         return next.newCall(method, callOptions);
213       }
214     };
215     Channel intercepted = ClientInterceptors.interceptForward(channel, interceptor1, interceptor2);
216     assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT));
217     assertEquals(Arrays.asList("i1", "i2", "channel"), order);
218   }
219 
220   @Test
callOptions()221   public void callOptions() {
222     final CallOptions initialCallOptions = CallOptions.DEFAULT.withDeadlineAfter(100, NANOSECONDS);
223     final CallOptions newCallOptions = initialCallOptions.withDeadlineAfter(300, NANOSECONDS);
224     assertNotSame(initialCallOptions, newCallOptions);
225     ClientInterceptor interceptor =
226         mock(ClientInterceptor.class, delegatesTo(new ClientInterceptor() {
227           @Override
228           public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
229               MethodDescriptor<ReqT, RespT> method,
230               CallOptions callOptions,
231               Channel next) {
232             return next.newCall(method, newCallOptions);
233           }
234         }));
235     Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
236     intercepted.newCall(method, initialCallOptions);
237     verify(interceptor)
238         .interceptCall(same(method), same(initialCallOptions), ArgumentMatchers.<Channel>any());
239     verify(channel).newCall(same(method), same(newCallOptions));
240   }
241 
242   @Test
addOutboundHeaders()243   public void addOutboundHeaders() {
244     final Metadata.Key<String> credKey = Metadata.Key.of("Cred", Metadata.ASCII_STRING_MARSHALLER);
245     ClientInterceptor interceptor = new ClientInterceptor() {
246       @Override
247       public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
248           MethodDescriptor<ReqT, RespT> method,
249           CallOptions callOptions,
250           Channel next) {
251         ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
252         return new SimpleForwardingClientCall<ReqT, RespT>(call) {
253           @Override
254           public void start(ClientCall.Listener<RespT> responseListener, Metadata headers) {
255             headers.put(credKey, "abcd");
256             super.start(responseListener, headers);
257           }
258         };
259       }
260     };
261     Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
262     @SuppressWarnings("unchecked")
263     ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class);
264     ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
265     // start() on the intercepted call will eventually reach the call created by the real channel
266     interceptedCall.start(listener, new Metadata());
267     // The headers passed to the real channel call will contain the information inserted by the
268     // interceptor.
269     assertSame(listener, call.listener);
270     assertEquals("abcd", call.headers.get(credKey));
271   }
272 
273   @Test
examineInboundHeaders()274   public void examineInboundHeaders() {
275     final List<Metadata> examinedHeaders = new ArrayList<>();
276     ClientInterceptor interceptor = new ClientInterceptor() {
277       @Override
278       public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
279           MethodDescriptor<ReqT, RespT> method,
280           CallOptions callOptions,
281           Channel next) {
282         ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
283         return new SimpleForwardingClientCall<ReqT, RespT>(call) {
284           @Override
285           public void start(ClientCall.Listener<RespT> responseListener, Metadata headers) {
286             super.start(new SimpleForwardingClientCallListener<RespT>(responseListener) {
287               @Override
288               public void onHeaders(Metadata headers) {
289                 examinedHeaders.add(headers);
290                 super.onHeaders(headers);
291               }
292             }, headers);
293           }
294         };
295       }
296     };
297     Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
298     @SuppressWarnings("unchecked")
299     ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class);
300     ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
301     interceptedCall.start(listener, new Metadata());
302     // Capture the underlying call listener that will receive headers from the transport.
303 
304     Metadata inboundHeaders = new Metadata();
305     // Simulate that a headers arrives on the underlying call listener.
306     call.listener.onHeaders(inboundHeaders);
307     assertThat(examinedHeaders).contains(inboundHeaders);
308   }
309 
310   @Test
normalCall()311   public void normalCall() {
312     ClientInterceptor interceptor = new ClientInterceptor() {
313       @Override
314       public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
315           MethodDescriptor<ReqT, RespT> method,
316           CallOptions callOptions,
317           Channel next) {
318         ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
319         return new SimpleForwardingClientCall<ReqT, RespT>(call) { };
320       }
321     };
322     Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
323     ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
324     assertNotSame(call, interceptedCall);
325     @SuppressWarnings("unchecked")
326     ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class);
327     Metadata headers = new Metadata();
328     interceptedCall.start(listener, headers);
329     assertSame(listener, call.listener);
330     assertSame(headers, call.headers);
331     interceptedCall.sendMessage(null /*request*/);
332     assertThat(call.messages).containsExactly((String) null);
333     interceptedCall.halfClose();
334     assertTrue(call.halfClosed);
335     interceptedCall.request(1);
336     assertThat(call.requests).containsExactly(1);
337   }
338 
339   @Test
exceptionInStart()340   public void exceptionInStart() {
341     final Exception error = new Exception("emulated error");
342     ClientInterceptor interceptor = new ClientInterceptor() {
343       @Override
344       public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
345           MethodDescriptor<ReqT, RespT> method,
346           CallOptions callOptions,
347           Channel next) {
348         ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
349         return new CheckedForwardingClientCall<ReqT, RespT>(call) {
350           @Override
351           protected void checkedStart(ClientCall.Listener<RespT> responseListener, Metadata headers)
352               throws Exception {
353             throw error;
354             // delegate().start will not be called
355           }
356         };
357       }
358     };
359     Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
360     @SuppressWarnings("unchecked")
361     ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class);
362     ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT);
363     assertNotSame(call, interceptedCall);
364     interceptedCall.start(listener, new Metadata());
365     interceptedCall.sendMessage(null /*request*/);
366     interceptedCall.halfClose();
367     interceptedCall.request(1);
368     call.done = true;
369     ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
370     verify(listener).onClose(captor.capture(), any(Metadata.class));
371     assertSame(error, captor.getValue().getCause());
372 
373     // Make sure nothing bad happens after the exception.
374     ClientCall<?, ?> noop = ((CheckedForwardingClientCall<?, ?>)interceptedCall).delegate();
375     // Should not throw, even on bad input
376     noop.cancel("Cancel for test", null);
377     noop.start(null, null);
378     noop.request(-1);
379     noop.halfClose();
380     noop.sendMessage(null);
381     assertFalse(noop.isReady());
382   }
383 
384   @Test
authorityIsDelegated()385   public void authorityIsDelegated() {
386     ClientInterceptor interceptor = new ClientInterceptor() {
387       @Override
388       public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
389           MethodDescriptor<ReqT, RespT> method,
390           CallOptions callOptions,
391           Channel next) {
392         return next.newCall(method, callOptions);
393       }
394     };
395 
396     when(channel.authority()).thenReturn("auth");
397     Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
398     assertEquals("auth", intercepted.authority());
399   }
400 
401   @Test
customOptionAccessible()402   public void customOptionAccessible() {
403     CallOptions.Key<String> customOption = CallOptions.Key.create("custom");
404     CallOptions callOptions = CallOptions.DEFAULT.withOption(customOption, "value");
405     ArgumentCaptor<CallOptions> passedOptions = ArgumentCaptor.forClass(CallOptions.class);
406     ClientInterceptor interceptor =
407         mock(ClientInterceptor.class, delegatesTo(new NoopInterceptor()));
408 
409     Channel intercepted = ClientInterceptors.intercept(channel, interceptor);
410 
411     assertSame(call, intercepted.newCall(method, callOptions));
412     verify(channel).newCall(same(method), same(callOptions));
413 
414     verify(interceptor).interceptCall(same(method), passedOptions.capture(), isA(Channel.class));
415     assertSame("value", passedOptions.getValue().getOption(customOption));
416   }
417 
418   private static class NoopInterceptor implements ClientInterceptor {
419     @Override
interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next)420     public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method,
421         CallOptions callOptions, Channel next) {
422       return next.newCall(method, callOptions);
423     }
424   }
425 
426   private static class BaseClientCall extends ClientCall<String, Integer> {
427     private boolean started;
428     private boolean done;
429     private ClientCall.Listener<Integer> listener;
430     private Metadata headers;
431     private List<Integer> requests = new ArrayList<>();
432     private List<String> messages = new ArrayList<>();
433     private boolean halfClosed;
434 
435     @Override
start(ClientCall.Listener<Integer> listener, Metadata headers)436     public void start(ClientCall.Listener<Integer> listener, Metadata headers) {
437       checkNotDone();
438       started = true;
439       this.listener = listener;
440       this.headers = headers;
441     }
442 
443     @Override
request(int numMessages)444     public void request(int numMessages) {
445       checkNotDone();
446       checkStarted();
447       requests.add(numMessages);
448     }
449 
450     @Override
cancel(String message, Throwable cause)451     public void cancel(String message, Throwable cause) {
452       checkNotDone();
453     }
454 
455     @Override
halfClose()456     public void halfClose() {
457       checkNotDone();
458       checkStarted();
459       this.halfClosed = true;
460     }
461 
462     @Override
sendMessage(String message)463     public void sendMessage(String message) {
464       checkNotDone();
465       checkStarted();
466       messages.add(message);
467     }
468 
checkNotDone()469     private void checkNotDone() {
470       if (done) {
471         throw new IllegalStateException("no more methods should be called");
472       }
473     }
474 
checkStarted()475     private void checkStarted() {
476       if (!started) {
477         throw new IllegalStateException("should have called start");
478       }
479     }
480   }
481 }
482