1 /* 2 * Copyright 2014 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static java.util.concurrent.TimeUnit.NANOSECONDS; 21 import static org.junit.Assert.assertEquals; 22 import static org.junit.Assert.assertFalse; 23 import static org.junit.Assert.assertNotSame; 24 import static org.junit.Assert.assertSame; 25 import static org.junit.Assert.assertTrue; 26 import static org.mockito.AdditionalAnswers.delegatesTo; 27 import static org.mockito.ArgumentMatchers.any; 28 import static org.mockito.ArgumentMatchers.isA; 29 import static org.mockito.ArgumentMatchers.same; 30 import static org.mockito.Mockito.mock; 31 import static org.mockito.Mockito.times; 32 import static org.mockito.Mockito.verify; 33 import static org.mockito.Mockito.verifyNoMoreInteractions; 34 import static org.mockito.Mockito.when; 35 36 import io.grpc.ClientInterceptors.CheckedForwardingClientCall; 37 import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; 38 import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; 39 import io.grpc.testing.TestMethodDescriptors; 40 import java.util.ArrayList; 41 import java.util.Arrays; 42 import java.util.List; 43 import org.junit.Before; 44 import org.junit.Rule; 45 import org.junit.Test; 46 import org.junit.runner.RunWith; 47 import org.junit.runners.JUnit4; 48 import org.mockito.ArgumentCaptor; 49 import org.mockito.ArgumentMatchers; 50 import org.mockito.Mock; 51 import org.mockito.junit.MockitoJUnit; 52 import org.mockito.junit.MockitoRule; 53 54 /** Unit tests for {@link ClientInterceptors}. */ 55 @RunWith(JUnit4.class) 56 public class ClientInterceptorsTest { 57 58 @Rule 59 public final MockitoRule mocks = MockitoJUnit.rule(); 60 61 @Mock 62 private Channel channel; 63 64 private BaseClientCall call = new BaseClientCall(); 65 66 private final MethodDescriptor<Void, Void> method = TestMethodDescriptors.voidMethod(); 67 68 /** 69 * Sets up mocks. 70 */ setUp()71 @Before public void setUp() { 72 when(channel.newCall( 73 ArgumentMatchers.<MethodDescriptor<String, Integer>>any(), any(CallOptions.class))) 74 .thenReturn(call); 75 } 76 77 @Test(expected = NullPointerException.class) npeForNullChannel()78 public void npeForNullChannel() { 79 ClientInterceptors.intercept(null, Arrays.<ClientInterceptor>asList()); 80 } 81 82 @Test(expected = NullPointerException.class) npeForNullInterceptorList()83 public void npeForNullInterceptorList() { 84 ClientInterceptors.intercept(channel, (List<ClientInterceptor>) null); 85 } 86 87 @Test(expected = NullPointerException.class) npeForNullInterceptor()88 public void npeForNullInterceptor() { 89 ClientInterceptors.intercept(channel, (ClientInterceptor) null); 90 } 91 92 @Test noop()93 public void noop() { 94 assertSame(channel, ClientInterceptors.intercept(channel, Arrays.<ClientInterceptor>asList())); 95 } 96 97 @Test channelAndInterceptorCalled()98 public void channelAndInterceptorCalled() { 99 ClientInterceptor interceptor = 100 mock(ClientInterceptor.class, delegatesTo(new NoopInterceptor())); 101 Channel intercepted = ClientInterceptors.intercept(channel, interceptor); 102 CallOptions callOptions = CallOptions.DEFAULT; 103 // First call 104 assertSame(call, intercepted.newCall(method, callOptions)); 105 verify(channel).newCall(same(method), same(callOptions)); 106 verify(interceptor) 107 .interceptCall(same(method), same(callOptions), ArgumentMatchers.<Channel>any()); 108 verifyNoMoreInteractions(channel, interceptor); 109 // Second call 110 assertSame(call, intercepted.newCall(method, callOptions)); 111 verify(channel, times(2)).newCall(same(method), same(callOptions)); 112 verify(interceptor, times(2)) 113 .interceptCall(same(method), same(callOptions), ArgumentMatchers.<Channel>any()); 114 verifyNoMoreInteractions(channel, interceptor); 115 } 116 117 @Test callNextTwice()118 public void callNextTwice() { 119 ClientInterceptor interceptor = new ClientInterceptor() { 120 @Override 121 public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( 122 MethodDescriptor<ReqT, RespT> method, 123 CallOptions callOptions, 124 Channel next) { 125 // Calling next twice is permitted, although should only rarely be useful. 126 assertSame(call, next.newCall(method, callOptions)); 127 return next.newCall(method, callOptions); 128 } 129 }; 130 Channel intercepted = ClientInterceptors.intercept(channel, interceptor); 131 assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT)); 132 verify(channel, times(2)).newCall(same(method), same(CallOptions.DEFAULT)); 133 verifyNoMoreInteractions(channel); 134 } 135 136 @Test ordered()137 public void ordered() { 138 final List<String> order = new ArrayList<>(); 139 channel = new Channel() { 140 @SuppressWarnings("unchecked") 141 @Override 142 public <ReqT, RespT> ClientCall<ReqT, RespT> newCall( 143 MethodDescriptor<ReqT, RespT> method, CallOptions callOptions) { 144 order.add("channel"); 145 return (ClientCall<ReqT, RespT>) call; 146 } 147 148 @Override 149 public String authority() { 150 return null; 151 } 152 }; 153 ClientInterceptor interceptor1 = new ClientInterceptor() { 154 @Override 155 public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( 156 MethodDescriptor<ReqT, RespT> method, 157 CallOptions callOptions, 158 Channel next) { 159 order.add("i1"); 160 return next.newCall(method, callOptions); 161 } 162 }; 163 ClientInterceptor interceptor2 = new ClientInterceptor() { 164 @Override 165 public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( 166 MethodDescriptor<ReqT, RespT> method, 167 CallOptions callOptions, 168 Channel next) { 169 order.add("i2"); 170 return next.newCall(method, callOptions); 171 } 172 }; 173 Channel intercepted = ClientInterceptors.intercept(channel, interceptor1, interceptor2); 174 assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT)); 175 assertEquals(Arrays.asList("i2", "i1", "channel"), order); 176 } 177 178 @Test orderedForward()179 public void orderedForward() { 180 final List<String> order = new ArrayList<>(); 181 channel = new Channel() { 182 @SuppressWarnings("unchecked") 183 @Override 184 public <ReqT, RespT> ClientCall<ReqT, RespT> newCall( 185 MethodDescriptor<ReqT, RespT> method, CallOptions callOptions) { 186 order.add("channel"); 187 return (ClientCall<ReqT, RespT>) call; 188 } 189 190 @Override 191 public String authority() { 192 return null; 193 } 194 }; 195 ClientInterceptor interceptor1 = new ClientInterceptor() { 196 @Override 197 public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( 198 MethodDescriptor<ReqT, RespT> method, 199 CallOptions callOptions, 200 Channel next) { 201 order.add("i1"); 202 return next.newCall(method, callOptions); 203 } 204 }; 205 ClientInterceptor interceptor2 = new ClientInterceptor() { 206 @Override 207 public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( 208 MethodDescriptor<ReqT, RespT> method, 209 CallOptions callOptions, 210 Channel next) { 211 order.add("i2"); 212 return next.newCall(method, callOptions); 213 } 214 }; 215 Channel intercepted = ClientInterceptors.interceptForward(channel, interceptor1, interceptor2); 216 assertSame(call, intercepted.newCall(method, CallOptions.DEFAULT)); 217 assertEquals(Arrays.asList("i1", "i2", "channel"), order); 218 } 219 220 @Test callOptions()221 public void callOptions() { 222 final CallOptions initialCallOptions = CallOptions.DEFAULT.withDeadlineAfter(100, NANOSECONDS); 223 final CallOptions newCallOptions = initialCallOptions.withDeadlineAfter(300, NANOSECONDS); 224 assertNotSame(initialCallOptions, newCallOptions); 225 ClientInterceptor interceptor = 226 mock(ClientInterceptor.class, delegatesTo(new ClientInterceptor() { 227 @Override 228 public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( 229 MethodDescriptor<ReqT, RespT> method, 230 CallOptions callOptions, 231 Channel next) { 232 return next.newCall(method, newCallOptions); 233 } 234 })); 235 Channel intercepted = ClientInterceptors.intercept(channel, interceptor); 236 intercepted.newCall(method, initialCallOptions); 237 verify(interceptor) 238 .interceptCall(same(method), same(initialCallOptions), ArgumentMatchers.<Channel>any()); 239 verify(channel).newCall(same(method), same(newCallOptions)); 240 } 241 242 @Test addOutboundHeaders()243 public void addOutboundHeaders() { 244 final Metadata.Key<String> credKey = Metadata.Key.of("Cred", Metadata.ASCII_STRING_MARSHALLER); 245 ClientInterceptor interceptor = new ClientInterceptor() { 246 @Override 247 public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( 248 MethodDescriptor<ReqT, RespT> method, 249 CallOptions callOptions, 250 Channel next) { 251 ClientCall<ReqT, RespT> call = next.newCall(method, callOptions); 252 return new SimpleForwardingClientCall<ReqT, RespT>(call) { 253 @Override 254 public void start(ClientCall.Listener<RespT> responseListener, Metadata headers) { 255 headers.put(credKey, "abcd"); 256 super.start(responseListener, headers); 257 } 258 }; 259 } 260 }; 261 Channel intercepted = ClientInterceptors.intercept(channel, interceptor); 262 @SuppressWarnings("unchecked") 263 ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class); 264 ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT); 265 // start() on the intercepted call will eventually reach the call created by the real channel 266 interceptedCall.start(listener, new Metadata()); 267 // The headers passed to the real channel call will contain the information inserted by the 268 // interceptor. 269 assertSame(listener, call.listener); 270 assertEquals("abcd", call.headers.get(credKey)); 271 } 272 273 @Test examineInboundHeaders()274 public void examineInboundHeaders() { 275 final List<Metadata> examinedHeaders = new ArrayList<>(); 276 ClientInterceptor interceptor = new ClientInterceptor() { 277 @Override 278 public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( 279 MethodDescriptor<ReqT, RespT> method, 280 CallOptions callOptions, 281 Channel next) { 282 ClientCall<ReqT, RespT> call = next.newCall(method, callOptions); 283 return new SimpleForwardingClientCall<ReqT, RespT>(call) { 284 @Override 285 public void start(ClientCall.Listener<RespT> responseListener, Metadata headers) { 286 super.start(new SimpleForwardingClientCallListener<RespT>(responseListener) { 287 @Override 288 public void onHeaders(Metadata headers) { 289 examinedHeaders.add(headers); 290 super.onHeaders(headers); 291 } 292 }, headers); 293 } 294 }; 295 } 296 }; 297 Channel intercepted = ClientInterceptors.intercept(channel, interceptor); 298 @SuppressWarnings("unchecked") 299 ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class); 300 ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT); 301 interceptedCall.start(listener, new Metadata()); 302 // Capture the underlying call listener that will receive headers from the transport. 303 304 Metadata inboundHeaders = new Metadata(); 305 // Simulate that a headers arrives on the underlying call listener. 306 call.listener.onHeaders(inboundHeaders); 307 assertThat(examinedHeaders).contains(inboundHeaders); 308 } 309 310 @Test normalCall()311 public void normalCall() { 312 ClientInterceptor interceptor = new ClientInterceptor() { 313 @Override 314 public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( 315 MethodDescriptor<ReqT, RespT> method, 316 CallOptions callOptions, 317 Channel next) { 318 ClientCall<ReqT, RespT> call = next.newCall(method, callOptions); 319 return new SimpleForwardingClientCall<ReqT, RespT>(call) { }; 320 } 321 }; 322 Channel intercepted = ClientInterceptors.intercept(channel, interceptor); 323 ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT); 324 assertNotSame(call, interceptedCall); 325 @SuppressWarnings("unchecked") 326 ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class); 327 Metadata headers = new Metadata(); 328 interceptedCall.start(listener, headers); 329 assertSame(listener, call.listener); 330 assertSame(headers, call.headers); 331 interceptedCall.sendMessage(null /*request*/); 332 assertThat(call.messages).containsExactly((String) null); 333 interceptedCall.halfClose(); 334 assertTrue(call.halfClosed); 335 interceptedCall.request(1); 336 assertThat(call.requests).containsExactly(1); 337 } 338 339 @Test exceptionInStart()340 public void exceptionInStart() { 341 final Exception error = new Exception("emulated error"); 342 ClientInterceptor interceptor = new ClientInterceptor() { 343 @Override 344 public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( 345 MethodDescriptor<ReqT, RespT> method, 346 CallOptions callOptions, 347 Channel next) { 348 ClientCall<ReqT, RespT> call = next.newCall(method, callOptions); 349 return new CheckedForwardingClientCall<ReqT, RespT>(call) { 350 @Override 351 protected void checkedStart(ClientCall.Listener<RespT> responseListener, Metadata headers) 352 throws Exception { 353 throw error; 354 // delegate().start will not be called 355 } 356 }; 357 } 358 }; 359 Channel intercepted = ClientInterceptors.intercept(channel, interceptor); 360 @SuppressWarnings("unchecked") 361 ClientCall.Listener<Void> listener = mock(ClientCall.Listener.class); 362 ClientCall<Void, Void> interceptedCall = intercepted.newCall(method, CallOptions.DEFAULT); 363 assertNotSame(call, interceptedCall); 364 interceptedCall.start(listener, new Metadata()); 365 interceptedCall.sendMessage(null /*request*/); 366 interceptedCall.halfClose(); 367 interceptedCall.request(1); 368 call.done = true; 369 ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class); 370 verify(listener).onClose(captor.capture(), any(Metadata.class)); 371 assertSame(error, captor.getValue().getCause()); 372 373 // Make sure nothing bad happens after the exception. 374 ClientCall<?, ?> noop = ((CheckedForwardingClientCall<?, ?>)interceptedCall).delegate(); 375 // Should not throw, even on bad input 376 noop.cancel("Cancel for test", null); 377 noop.start(null, null); 378 noop.request(-1); 379 noop.halfClose(); 380 noop.sendMessage(null); 381 assertFalse(noop.isReady()); 382 } 383 384 @Test authorityIsDelegated()385 public void authorityIsDelegated() { 386 ClientInterceptor interceptor = new ClientInterceptor() { 387 @Override 388 public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( 389 MethodDescriptor<ReqT, RespT> method, 390 CallOptions callOptions, 391 Channel next) { 392 return next.newCall(method, callOptions); 393 } 394 }; 395 396 when(channel.authority()).thenReturn("auth"); 397 Channel intercepted = ClientInterceptors.intercept(channel, interceptor); 398 assertEquals("auth", intercepted.authority()); 399 } 400 401 @Test customOptionAccessible()402 public void customOptionAccessible() { 403 CallOptions.Key<String> customOption = CallOptions.Key.create("custom"); 404 CallOptions callOptions = CallOptions.DEFAULT.withOption(customOption, "value"); 405 ArgumentCaptor<CallOptions> passedOptions = ArgumentCaptor.forClass(CallOptions.class); 406 ClientInterceptor interceptor = 407 mock(ClientInterceptor.class, delegatesTo(new NoopInterceptor())); 408 409 Channel intercepted = ClientInterceptors.intercept(channel, interceptor); 410 411 assertSame(call, intercepted.newCall(method, callOptions)); 412 verify(channel).newCall(same(method), same(callOptions)); 413 414 verify(interceptor).interceptCall(same(method), passedOptions.capture(), isA(Channel.class)); 415 assertSame("value", passedOptions.getValue().getOption(customOption)); 416 } 417 418 private static class NoopInterceptor implements ClientInterceptor { 419 @Override interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next)420 public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, 421 CallOptions callOptions, Channel next) { 422 return next.newCall(method, callOptions); 423 } 424 } 425 426 private static class BaseClientCall extends ClientCall<String, Integer> { 427 private boolean started; 428 private boolean done; 429 private ClientCall.Listener<Integer> listener; 430 private Metadata headers; 431 private List<Integer> requests = new ArrayList<>(); 432 private List<String> messages = new ArrayList<>(); 433 private boolean halfClosed; 434 435 @Override start(ClientCall.Listener<Integer> listener, Metadata headers)436 public void start(ClientCall.Listener<Integer> listener, Metadata headers) { 437 checkNotDone(); 438 started = true; 439 this.listener = listener; 440 this.headers = headers; 441 } 442 443 @Override request(int numMessages)444 public void request(int numMessages) { 445 checkNotDone(); 446 checkStarted(); 447 requests.add(numMessages); 448 } 449 450 @Override cancel(String message, Throwable cause)451 public void cancel(String message, Throwable cause) { 452 checkNotDone(); 453 } 454 455 @Override halfClose()456 public void halfClose() { 457 checkNotDone(); 458 checkStarted(); 459 this.halfClosed = true; 460 } 461 462 @Override sendMessage(String message)463 public void sendMessage(String message) { 464 checkNotDone(); 465 checkStarted(); 466 messages.add(message); 467 } 468 checkNotDone()469 private void checkNotDone() { 470 if (done) { 471 throw new IllegalStateException("no more methods should be called"); 472 } 473 } 474 checkStarted()475 private void checkStarted() { 476 if (!started) { 477 throw new IllegalStateException("should have called start"); 478 } 479 } 480 } 481 } 482