/*
 * Copyright 2014 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc;

import static com.google.common.collect.Iterables.getOnlyElement;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.same;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;

import io.grpc.MethodDescriptor.Marshaller;
import io.grpc.MethodDescriptor.MethodType;
import io.grpc.ServerCall.Listener;
import io.grpc.internal.NoopServerCall;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;

/** Unit tests for {@link ServerInterceptors}. */
@RunWith(JUnit4.class)
public class ServerInterceptorsTest {
  @Rule
  public final MockitoRule mocks = MockitoJUnit.rule();

  @Mock
  private Marshaller<String> requestMarshaller;

  @Mock
  private Marshaller<Integer> responseMarshaller;

  @Mock
  private ServerCallHandler<String, Integer> handler;

  @Mock
  private ServerCall.Listener<String> listener;

  private MethodDescriptor<String, Integer> flowMethod;

  private final ServerCall<String, Integer> call = new NoopServerCall<>();

  private ServerServiceDefinition serviceDefinition;

  private final Metadata headers = new Metadata();

  /**
   * Builds the single-method {@code basic/flow} service used by most tests and stubs the mock
   * handler so that every {@code startCall} returns the mock {@link #listener}.
   */
  @Before
  public void setUp() {
    flowMethod = MethodDescriptor.<String, Integer>newBuilder()
        .setType(MethodType.UNKNOWN)
        .setFullMethodName("basic/flow")
        .setRequestMarshaller(requestMarshaller)
        .setResponseMarshaller(responseMarshaller)
        .build();

    Mockito.when(
            handler.startCall(
                ArgumentMatchers.<ServerCall<String, Integer>>any(),
                ArgumentMatchers.<Metadata>any()))
        .thenReturn(listener);

    serviceDefinition = ServerServiceDefinition.builder(new ServiceDescriptor("basic", flowMethod))
        .addMethod(flowMethod, handler).build();
  }

  /** Final check for all tests: the marshaller and listener mocks must never be touched. */
  @After
  public void makeSureExpectedMocksUnused() {
    verifyNoInteractions(requestMarshaller);
    verifyNoInteractions(responseMarshaller);
    verifyNoInteractions(listener);
  }

  /** A null service definition is rejected with a {@link NullPointerException}. */
  @Test
  public void npeForNullServiceDefinition() {
    ServerServiceDefinition serviceDef = null;
    List<ServerInterceptor> interceptors = Arrays.asList();
    assertThrows(
        NullPointerException.class,
        () -> ServerInterceptors.intercept(serviceDef, interceptors));
  }

  /** A null interceptor list is rejected with a {@link NullPointerException}. */
  @Test
  public void npeForNullInterceptorList() {
    assertThrows(
        NullPointerException.class,
        () -> ServerInterceptors.intercept(serviceDefinition, (List<ServerInterceptor>) null));
  }

  /** A null element inside the interceptor list is rejected with a NullPointerException. */
  @Test
  public void npeForNullInterceptor() {
    List<ServerInterceptor> interceptors = Arrays.asList((ServerInterceptor) null);
    assertThrows(
        NullPointerException.class,
        () -> ServerInterceptors.intercept(serviceDefinition, interceptors));
  }

  /** Intercepting with an empty interceptor list hands back the very same definition. */
  @Test
  public void noop() {
    assertSame(serviceDefinition,
        ServerInterceptors.intercept(serviceDefinition, Arrays.<ServerInterceptor>asList()));
  }

  /** Each {@code startCall} on the intercepted handler reaches interceptor and handler once. */
  @Test
  public void multipleInvocationsOfHandler() {
    ServerInterceptor interceptor =
        mock(ServerInterceptor.class, delegatesTo(new NoopInterceptor()));
    ServerServiceDefinition intercepted
        = ServerInterceptors.intercept(serviceDefinition, Arrays.asList(interceptor));

    // First call.
    assertSame(listener,
        getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    verify(interceptor).interceptCall(same(call), same(headers), anyCallHandler());
    verify(handler).startCall(call, headers);
    verifyNoMoreInteractions(interceptor, handler);

    // Second call: the invocation counts go up to exactly two.
    assertSame(listener,
        getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    verify(interceptor, times(2))
        .interceptCall(same(call), same(headers), anyCallHandler());
    verify(handler, times(2)).startCall(call, headers);
    verifyNoMoreInteractions(interceptor, handler);
  }

  /** With two methods registered, a call is routed only to the handler of the named method. */
  @Test
  public void correctHandlerCalled() {
    @SuppressWarnings("unchecked")
    ServerCallHandler<String, Integer> handler2 = mock(ServerCallHandler.class);
    MethodDescriptor<String, Integer> flowMethod2 =
        flowMethod.toBuilder().setFullMethodName("basic/flow2").build();
    serviceDefinition = ServerServiceDefinition.builder(
        new ServiceDescriptor("basic", flowMethod, flowMethod2))
        .addMethod(flowMethod, handler)
        .addMethod(flowMethod2, handler2).build();
    ServerServiceDefinition intercepted = ServerInterceptors.intercept(
        serviceDefinition, Arrays.<ServerInterceptor>asList(new NoopInterceptor()));

    getMethod(intercepted, "basic/flow").getServerCallHandler().startCall(call, headers);
    verify(handler).startCall(call, headers);
    verifyNoMoreInteractions(handler);
    verifyNoMoreInteractions(handler2);

    getMethod(intercepted, "basic/flow2").getServerCallHandler().startCall(call, headers);
    verify(handler2).startCall(call, headers);
    verifyNoMoreInteractions(handler);
    verifyNoMoreInteractions(handler2);
  }

  /** An interceptor may invoke {@code next.startCall} more than once. */
  @Test
  public void callNextTwice() {
    ServerInterceptor interceptor = new ServerInterceptor() {
      @Override
      public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
          ServerCall<ReqT, RespT> call,
          Metadata headers,
          ServerCallHandler<ReqT, RespT> next) {
        // Calling next twice is permitted, although should only rarely be useful.
        assertSame(listener, next.startCall(call, headers));
        return next.startCall(call, headers);
      }
    };
    ServerServiceDefinition intercepted = ServerInterceptors.intercept(serviceDefinition,
        interceptor);
    assertSame(listener,
        getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    verify(handler, times(2)).startCall(same(call), same(headers));
    verifyNoMoreInteractions(handler);
  }

  /** {@code intercept} runs the last interceptor in the list first. */
  @Test
  public void ordered() {
    final List<String> order = new ArrayList<>();
    ServerCallHandler<String, Integer> recordingHandler = (call, headers) -> {
      order.add("handler");
      return listener;
    };
    ServerInterceptor interceptor1 = recordingInterceptor("i1", order);
    ServerInterceptor interceptor2 = recordingInterceptor("i2", order);
    ServerServiceDefinition recordingService = ServerServiceDefinition.builder(
        new ServiceDescriptor("basic", flowMethod))
        .addMethod(flowMethod, recordingHandler).build();
    ServerServiceDefinition intercepted = ServerInterceptors.intercept(
        recordingService, Arrays.asList(interceptor1, interceptor2));
    assertSame(listener,
        getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    assertEquals(Arrays.asList("i2", "i1", "handler"), order);
  }

  /** {@code interceptForward} runs the interceptors in the order they were given. */
  @Test
  public void orderedForward() {
    final List<String> order = new ArrayList<>();
    ServerCallHandler<String, Integer> recordingHandler = (call, headers) -> {
      order.add("handler");
      return listener;
    };
    ServerInterceptor interceptor1 = recordingInterceptor("i1", order);
    ServerInterceptor interceptor2 = recordingInterceptor("i2", order);
    ServerServiceDefinition recordingService = ServerServiceDefinition.builder(
        new ServiceDescriptor("basic", flowMethod))
        .addMethod(flowMethod, recordingHandler).build();
    ServerServiceDefinition intercepted = ServerInterceptors.interceptForward(
        recordingService, interceptor1, interceptor2);
    assertSame(listener,
        getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    assertEquals(Arrays.asList("i1", "i2", "handler"), order);
  }

  /**
   * The interceptor receives the original call, the handler receives whatever call the
   * interceptor passes on, and the caller receives whatever listener the interceptor returns.
   */
  @Test
  public void argumentsPassed() {
    final ServerCall<String, Integer> call2 = new NoopServerCall<>();
    @SuppressWarnings("unchecked")
    final ServerCall.Listener<String> listener2 = mock(ServerCall.Listener.class);

    ServerInterceptor interceptor = new ServerInterceptor() {
      @SuppressWarnings("unchecked") // Lots of casting for no benefit. Not intended use.
      @Override
      public <R1, R2> ServerCall.Listener<R1> interceptCall(
          ServerCall<R1, R2> call,
          Metadata headers,
          ServerCallHandler<R1, R2> next) {
        // Expected value first: the call handed to the interceptor must be the test's call.
        assertSame(ServerInterceptorsTest.this.call, call);
        assertSame(listener,
            next.startCall((ServerCall<R1, R2>) call2, headers));
        return (ServerCall.Listener<R1>) listener2;
      }
    };
    ServerServiceDefinition intercepted = ServerInterceptors.intercept(
        serviceDefinition, Arrays.asList(interceptor));
    assertSame(listener2,
        getSoleMethod(intercepted).getServerCallHandler().startCall(call, headers));
    verify(handler).startCall(call2, headers);
  }

  /**
   * Wraps a {@link Holder}-typed service with one interceptor, converts it with
   * {@link ServerInterceptors#useInputStreamMessages(ServerServiceDefinition)}, and wraps the
   * result with a second interceptor. The first interceptor must observe {@link Holder}
   * messages and the second must observe {@link InputStream} messages.
   */
  @Test
  @SuppressWarnings("unchecked")
  public void typedMarshalledMessages() {
    final List<String> order = new ArrayList<>();
    Marshaller<Holder> marshaller = new HolderMarshaller();

    MethodDescriptor<Holder, Holder> wrappedMethod = MethodDescriptor.<Holder, Holder>newBuilder()
        .setType(MethodType.UNKNOWN)
        .setFullMethodName("basic/wrapped")
        .setRequestMarshaller(marshaller)
        .setResponseMarshaller(marshaller)
        .build();
    ServerServiceDefinition serviceDef = ServerServiceDefinition.builder(
        new ServiceDescriptor("basic", wrappedMethod))
        .addMethod(wrappedMethod, echoingHandler(order)).build();

    // i1 is installed before the InputStream conversion, i2 after it.
    ServerInterceptor interceptor1 = typeCheckingInterceptor("i1", Holder.class, order);
    ServerInterceptor interceptor2 = typeCheckingInterceptor("i2", InputStream.class, order);

    ServerServiceDefinition intercepted = ServerInterceptors.intercept(serviceDef, interceptor1);
    ServerServiceDefinition inputStreamMessageService = ServerInterceptors
        .useInputStreamMessages(intercepted);
    ServerServiceDefinition intercepted2 = ServerInterceptors
        .intercept(inputStreamMessageService, interceptor2);
    ServerMethodDefinition<InputStream, InputStream> serverMethod =
        (ServerMethodDefinition<InputStream, InputStream>) intercepted2.getMethod("basic/wrapped");
    ServerCall<InputStream, InputStream> call2 = new NoopServerCall<>();
    byte[] bytes = {};
    serverMethod
        .getServerCallHandler()
        .startCall(call2, headers)
        .onMessage(new ByteArrayInputStream(bytes));
    assertEquals(
        Arrays.asList("i2onMessage", "i1onMessage", "handler", "i1sendMessage", "i2sendMessage"),
        order);
  }

  /**
   * Tests
   * {@link ServerInterceptors#useMarshalledMessages(ServerServiceDefinition, Marshaller,
   * Marshaller)} with two distinct marshallers. Makes sure that on an incoming request the
   * request marshaller's stream method is called, and on a response the response marshaller's
   * parse method is called.
   */
  @Test
  @SuppressWarnings("unchecked")
  public void distinctMarshallerForRequestAndResponse() {
    final List<String> requestFlowOrder = new ArrayList<>();

    final Marshaller<String> recordingRequestMarshaller = new Marshaller<String>() {
      @Override
      public InputStream stream(String value) {
        requestFlowOrder.add("RequestStream");
        return null;
      }

      @Override
      public String parse(InputStream stream) {
        requestFlowOrder.add("RequestParse");
        return null;
      }
    };
    final Marshaller<String> recordingResponseMarshaller = new Marshaller<String>() {
      @Override
      public InputStream stream(String value) {
        requestFlowOrder.add("ResponseStream");
        return null;
      }

      @Override
      public String parse(InputStream stream) {
        requestFlowOrder.add("ResponseParse");
        return null;
      }
    };
    final Marshaller<Holder> dummyMarshaller = new HolderMarshaller();

    MethodDescriptor<Holder, Holder> wrappedMethod = MethodDescriptor.<Holder, Holder>newBuilder()
        .setType(MethodType.UNKNOWN)
        .setFullMethodName("basic/wrapped")
        .setRequestMarshaller(dummyMarshaller)
        .setResponseMarshaller(dummyMarshaller)
        .build();
    ServerServiceDefinition serviceDef = ServerServiceDefinition.builder(
        new ServiceDescriptor("basic", wrappedMethod))
        .addMethod(wrappedMethod, echoingHandler(requestFlowOrder)).build();
    ServerServiceDefinition intercepted = ServerInterceptors.useMarshalledMessages(serviceDef,
        recordingRequestMarshaller, recordingResponseMarshaller);
    ServerMethodDefinition<String, String> serverMethod =
        (ServerMethodDefinition<String, String>) intercepted.getMethod("basic/wrapped");
    ServerCall<String, String> serverCall = new NoopServerCall<>();
    serverMethod.getServerCallHandler().startCall(serverCall, headers).onMessage("TestMessage");

    assertEquals(Arrays.asList("RequestStream", "handler", "ResponseParse"), requestFlowOrder);
  }

  /** Returns the only method of {@code serviceDef}, failing if there is not exactly one. */
  @SuppressWarnings("unchecked")
  private static ServerMethodDefinition<String, Integer> getSoleMethod(
      ServerServiceDefinition serviceDef) {
    if (serviceDef.getMethods().size() != 1) {
      throw new AssertionError("Not exactly one method present");
    }
    return (ServerMethodDefinition<String, Integer>) getOnlyElement(serviceDef.getMethods());
  }

  /** Looks up the method called {@code name} in {@code serviceDef}, cast to the test's types. */
  @SuppressWarnings("unchecked")
  private static ServerMethodDefinition<String, Integer> getMethod(
      ServerServiceDefinition serviceDef, String name) {
    return (ServerMethodDefinition<String, Integer>) serviceDef.getMethod(name);
  }

  /** Typed Mockito {@code any()} matcher for the handler argument of {@code interceptCall}. */
  private ServerCallHandler<String, Integer> anyCallHandler() {
    return ArgumentMatchers.any();
  }

  /** Creates an interceptor that appends {@code name} to {@code order} and then delegates. */
  private static ServerInterceptor recordingInterceptor(
      final String name, final List<String> order) {
    return new ServerInterceptor() {
      @Override
      public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
          ServerCall<ReqT, RespT> call,
          Metadata headers,
          ServerCallHandler<ReqT, RespT> next) {
        order.add(name);
        return next.startCall(call, headers);
      }
    };
  }

  /** Creates a handler whose listener records "handler" and sends each message straight back. */
  private static ServerCallHandler<Holder, Holder> echoingHandler(final List<String> order) {
    return (call, headers) -> new Listener<Holder>() {
      @Override
      public void onMessage(Holder message) {
        order.add("handler");
        call.sendMessage(message);
      }
    };
  }

  /**
   * Creates an interceptor that records {@code tag + "onMessage"} / {@code tag + "sendMessage"}
   * in {@code order} and asserts every message it sees is an instance of
   * {@code expectedMessageType} before forwarding it.
   */
  private static ServerInterceptor typeCheckingInterceptor(
      final String tag, final Class<?> expectedMessageType, final List<String> order) {
    return new ServerInterceptor() {
      @Override
      public <ReqT, RespT> Listener<ReqT> interceptCall(
          ServerCall<ReqT, RespT> call,
          Metadata headers,
          ServerCallHandler<ReqT, RespT> next) {
        ServerCall<ReqT, RespT> interceptedCall =
            new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(call) {
              @Override
              public void sendMessage(RespT message) {
                order.add(tag + "sendMessage");
                assertTrue(expectedMessageType.isInstance(message));
                super.sendMessage(message);
              }
            };

        ServerCall.Listener<ReqT> originalListener = next.startCall(interceptedCall, headers);
        return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(
            originalListener) {
          @Override
          public void onMessage(ReqT message) {
            order.add(tag + "onMessage");
            assertTrue(expectedMessageType.isInstance(message));
            super.onMessage(message);
          }
        };
      }
    };
  }

  /** Interceptor that passes the call straight to the next handler. */
  private static class NoopInterceptor implements ServerInterceptor {
    @Override
    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
        ServerCall<ReqT, RespT> call,
        Metadata headers,
        ServerCallHandler<ReqT, RespT> next) {
      return next.startCall(call, headers);
    }
  }

  /** Message type that simply wraps the {@link InputStream} it was created from. */
  private static class Holder {
    private final InputStream inputStream;

    Holder(InputStream inputStream) {
      this.inputStream = inputStream;
    }

    public InputStream get() {
      return inputStream;
    }
  }

  /** Marshaller that wraps a stream in a {@link Holder} on parse and unwraps it on stream. */
  private static final class HolderMarshaller implements Marshaller<Holder> {
    @Override
    public InputStream stream(Holder value) {
      return value.get();
    }

    @Override
    public Holder parse(InputStream stream) {
      return new Holder(stream);
    }
  }
}