• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2014 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.netty;
18 
19 import static com.google.common.base.Charsets.UTF_8;
20 import static com.google.common.truth.Truth.assertThat;
21 import static com.google.common.truth.TruthJUnit.assume;
22 import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE;
23 import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS;
24 import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIME_NANOS;
25 import static io.grpc.internal.GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED;
26 import static io.grpc.internal.GrpcUtil.USER_AGENT_KEY;
27 import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE;
28 import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_AGE_NANOS_DISABLED;
29 import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED;
30 import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE;
31 import static org.junit.Assert.assertEquals;
32 import static org.junit.Assert.assertFalse;
33 import static org.junit.Assert.assertNotNull;
34 import static org.junit.Assert.assertNull;
35 import static org.junit.Assert.assertSame;
36 import static org.junit.Assert.assertTrue;
37 import static org.junit.Assert.fail;
38 
39 import com.google.common.base.Ticker;
40 import com.google.common.io.ByteStreams;
41 import com.google.common.util.concurrent.SettableFuture;
42 import io.grpc.Attributes;
43 import io.grpc.CallOptions;
44 import io.grpc.ChannelLogger;
45 import io.grpc.ClientStreamTracer;
46 import io.grpc.Grpc;
47 import io.grpc.InternalChannelz;
48 import io.grpc.Metadata;
49 import io.grpc.MethodDescriptor;
50 import io.grpc.MethodDescriptor.Marshaller;
51 import io.grpc.ServerStreamTracer;
52 import io.grpc.Status;
53 import io.grpc.Status.Code;
54 import io.grpc.StatusException;
55 import io.grpc.internal.ClientStream;
56 import io.grpc.internal.ClientStreamListener;
57 import io.grpc.internal.ClientTransport;
58 import io.grpc.internal.FakeClock;
59 import io.grpc.internal.FixedObjectPool;
60 import io.grpc.internal.GrpcUtil;
61 import io.grpc.internal.ManagedClientTransport;
62 import io.grpc.internal.ServerListener;
63 import io.grpc.internal.ServerStream;
64 import io.grpc.internal.ServerStreamListener;
65 import io.grpc.internal.ServerTransport;
66 import io.grpc.internal.ServerTransportListener;
67 import io.grpc.internal.TransportTracer;
68 import io.grpc.internal.testing.TestUtils;
69 import io.grpc.netty.NettyChannelBuilder.LocalSocketPicker;
70 import io.grpc.netty.NettyTestUtil.TrackingObjectPoolForTest;
71 import io.grpc.testing.TlsTesting;
72 import io.netty.channel.Channel;
73 import io.netty.channel.ChannelConfig;
74 import io.netty.channel.ChannelDuplexHandler;
75 import io.netty.channel.ChannelFactory;
76 import io.netty.channel.ChannelHandler;
77 import io.netty.channel.ChannelHandlerContext;
78 import io.netty.channel.ChannelOption;
79 import io.netty.channel.EventLoopGroup;
80 import io.netty.channel.ReflectiveChannelFactory;
81 import io.netty.channel.local.LocalChannel;
82 import io.netty.channel.nio.NioEventLoopGroup;
83 import io.netty.channel.socket.SocketChannelConfig;
84 import io.netty.channel.socket.nio.NioServerSocketChannel;
85 import io.netty.channel.socket.nio.NioSocketChannel;
86 import io.netty.handler.codec.http2.StreamBufferingEncoder;
87 import io.netty.handler.ssl.ClientAuth;
88 import io.netty.handler.ssl.SslContext;
89 import io.netty.util.AsciiString;
90 import java.io.ByteArrayInputStream;
91 import java.io.IOException;
92 import java.io.InputStream;
93 import java.net.InetSocketAddress;
94 import java.net.SocketAddress;
95 import java.util.ArrayList;
96 import java.util.Collections;
97 import java.util.HashMap;
98 import java.util.List;
99 import java.util.Map;
100 import java.util.concurrent.ExecutionException;
101 import java.util.concurrent.LinkedBlockingQueue;
102 import java.util.concurrent.TimeUnit;
103 import java.util.concurrent.TimeoutException;
104 import javax.annotation.Nullable;
105 import javax.net.ssl.SSLException;
106 import javax.net.ssl.SSLHandshakeException;
107 import org.junit.After;
108 import org.junit.Rule;
109 import org.junit.Test;
110 import org.junit.runner.RunWith;
111 import org.junit.runners.JUnit4;
112 import org.mockito.Mock;
113 import org.mockito.junit.MockitoJUnit;
114 import org.mockito.junit.MockitoRule;
115 
116 /**
117  * Tests for {@link NettyClientTransport}.
118  */
119 @RunWith(JUnit4.class)
120 public class NettyClientTransportTest {
121   @Rule public final MockitoRule mocks = MockitoJUnit.rule();
122 
123   private static final SslContext SSL_CONTEXT = createSslContext();
124 
125   @Mock
126   private ManagedClientTransport.Listener clientTransportListener;
127 
128   private final List<NettyClientTransport> transports = new ArrayList<>();
129   private final LinkedBlockingQueue<Attributes> serverTransportAttributesList =
130       new LinkedBlockingQueue<>();
131   private final NioEventLoopGroup group = new NioEventLoopGroup(1);
132   private final EchoServerListener serverListener = new EchoServerListener();
133   private final InternalChannelz channelz = new InternalChannelz();
134   private Runnable tooManyPingsRunnable = new Runnable() {
135     // Throwing is useless in this method, because Netty doesn't propagate the exception
136     @Override public void run() {}
137   };
138   private Attributes eagAttributes = Attributes.EMPTY;
139 
140   private ProtocolNegotiator negotiator = ProtocolNegotiators.serverTls(SSL_CONTEXT);
141 
142   private InetSocketAddress address;
143   private String authority;
144   private NettyServer server;
145 
146   @After
teardown()147   public void teardown() throws Exception {
148     for (NettyClientTransport transport : transports) {
149       transport.shutdown(Status.UNAVAILABLE);
150     }
151 
152     if (server != null) {
153       server.shutdown();
154     }
155 
156     group.shutdownGracefully(0, 10, TimeUnit.SECONDS);
157   }
158 
159   @Test
testToString()160   public void testToString() throws Exception {
161     address = TestUtils.testServerAddress(new InetSocketAddress(12345));
162     authority = GrpcUtil.authorityFromHostAndPort(address.getHostString(), address.getPort());
163     String s = newTransport(newNegotiator()).toString();
164     transports.clear();
165     assertTrue("Unexpected: " + s, s.contains("NettyClientTransport"));
166     assertTrue("Unexpected: " + s, s.contains(address.toString()));
167   }
168 
169   @Test
addDefaultUserAgent()170   public void addDefaultUserAgent() throws Exception {
171     startServer();
172     NettyClientTransport transport = newTransport(newNegotiator());
173     callMeMaybe(transport.start(clientTransportListener));
174 
175     // Send a single RPC and wait for the response.
176     new Rpc(transport).halfClose().waitForResponse();
177 
178     // Verify that the received headers contained the User-Agent.
179     assertEquals(1, serverListener.streamListeners.size());
180 
181     Metadata headers = serverListener.streamListeners.get(0).headers;
182     assertEquals(GrpcUtil.getGrpcUserAgent("netty", null), headers.get(USER_AGENT_KEY));
183   }
184 
185   @Test
setSoLingerChannelOption()186   public void setSoLingerChannelOption() throws IOException {
187     startServer();
188     Map<ChannelOption<?>, Object> channelOptions = new HashMap<>();
189     // set SO_LINGER option
190     int soLinger = 123;
191     channelOptions.put(ChannelOption.SO_LINGER, soLinger);
192     NettyClientTransport transport = new NettyClientTransport(
193         address, new ReflectiveChannelFactory<>(NioSocketChannel.class), channelOptions, group,
194         newNegotiator(), false, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE,
195         GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1L, false, authority,
196         null /* user agent */, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY,
197         new SocketPicker(), new FakeChannelLogger(), false, Ticker.systemTicker());
198     transports.add(transport);
199     callMeMaybe(transport.start(clientTransportListener));
200 
201     // verify SO_LINGER has been set
202     ChannelConfig config = transport.channel().config();
203     assertTrue(config instanceof SocketChannelConfig);
204     assertEquals(soLinger, ((SocketChannelConfig) config).getSoLinger());
205   }
206 
207   @Test
overrideDefaultUserAgent()208   public void overrideDefaultUserAgent() throws Exception {
209     startServer();
210     NettyClientTransport transport = newTransport(newNegotiator(),
211         DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, "testUserAgent", true);
212     callMeMaybe(transport.start(clientTransportListener));
213 
214     new Rpc(transport, new Metadata()).halfClose().waitForResponse();
215 
216     // Verify that the received headers contained the User-Agent.
217     assertEquals(1, serverListener.streamListeners.size());
218     Metadata receivedHeaders = serverListener.streamListeners.get(0).headers;
219     assertEquals(GrpcUtil.getGrpcUserAgent("netty", "testUserAgent"),
220         receivedHeaders.get(USER_AGENT_KEY));
221   }
222 
223   @Test
maxMessageSizeShouldBeEnforced()224   public void maxMessageSizeShouldBeEnforced() throws Throwable {
225     startServer();
226     // Allow the response payloads of up to 1 byte.
227     NettyClientTransport transport = newTransport(newNegotiator(),
228         1, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null, true);
229     callMeMaybe(transport.start(clientTransportListener));
230 
231     try {
232       // Send a single RPC and wait for the response.
233       new Rpc(transport).halfClose().waitForResponse();
234       fail("Expected the stream to fail.");
235     } catch (ExecutionException e) {
236       Status status = Status.fromThrowable(e);
237       assertEquals(Code.RESOURCE_EXHAUSTED, status.getCode());
238       assertTrue("Missing exceeds maximum from: " + status.getDescription(),
239           status.getDescription().contains("exceeds maximum"));
240     }
241   }
242 
243   /**
244    * Verifies that we can create multiple TLS client transports from the same builder.
245    */
246   @Test
creatingMultipleTlsTransportsShouldSucceed()247   public void creatingMultipleTlsTransportsShouldSucceed() throws Exception {
248     startServer();
249 
250     // Create a couple client transports.
251     ProtocolNegotiator negotiator = newNegotiator();
252     for (int index = 0; index < 2; ++index) {
253       NettyClientTransport transport = newTransport(negotiator);
254       callMeMaybe(transport.start(clientTransportListener));
255     }
256 
257     // Send a single RPC on each transport.
258     final List<Rpc> rpcs = new ArrayList<>(transports.size());
259     for (NettyClientTransport transport : transports) {
260       rpcs.add(new Rpc(transport).halfClose());
261     }
262 
263     // Wait for the RPCs to complete.
264     for (Rpc rpc : rpcs) {
265       rpc.waitForResponse();
266     }
267   }
268 
269   @Test
negotiationFailurePropagatesToStatus()270   public void negotiationFailurePropagatesToStatus() throws Exception {
271     negotiator = ProtocolNegotiators.serverPlaintext();
272     startServer();
273 
274     final NoopProtocolNegotiator negotiator = new NoopProtocolNegotiator();
275     final NettyClientTransport transport = newTransport(negotiator);
276     callMeMaybe(transport.start(clientTransportListener));
277     final Status failureStatus = Status.UNAVAILABLE.withDescription("oh noes!");
278     transport.channel().eventLoop().execute(new Runnable() {
279       @Override
280       public void run() {
281         negotiator.handler.fail(transport.channel().pipeline().context(negotiator.handler),
282             failureStatus.asRuntimeException());
283       }
284     });
285 
286     Rpc rpc = new Rpc(transport).halfClose();
287     try {
288       rpc.waitForClose();
289       fail("expected exception");
290     } catch (ExecutionException ex) {
291       Status actual = ((StatusException) ex.getCause()).getStatus();
292       assertSame(failureStatus.getCode(), actual.getCode());
293       assertThat(actual.getDescription()).contains(failureStatus.getDescription());
294     }
295   }
296 
297   @Test
tlsNegotiationFailurePropagatesToStatus()298   public void tlsNegotiationFailurePropagatesToStatus() throws Exception {
299     InputStream serverCert = TlsTesting.loadCert("server1.pem");
300     InputStream serverKey = TlsTesting.loadCert("server1.key");
301     // Don't trust ca.pem, so that client auth fails
302     SslContext sslContext = GrpcSslContexts.forServer(serverCert, serverKey)
303         .clientAuth(ClientAuth.REQUIRE)
304         .build();
305     negotiator = ProtocolNegotiators.serverTls(sslContext);
306     startServer();
307 
308     InputStream caCert = TlsTesting.loadCert("ca.pem");
309     InputStream clientCert = TlsTesting.loadCert("client.pem");
310     InputStream clientKey = TlsTesting.loadCert("client.key");
311     SslContext clientContext = GrpcSslContexts.forClient()
312         .trustManager(caCert)
313         .keyManager(clientCert, clientKey)
314         .build();
315     ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext);
316     final NettyClientTransport transport = newTransport(negotiator);
317     callMeMaybe(transport.start(clientTransportListener));
318 
319     Rpc rpc = new Rpc(transport).halfClose();
320     try {
321       rpc.waitForClose();
322       fail("expected exception");
323     } catch (ExecutionException ex) {
324       StatusException sre = (StatusException) ex.getCause();
325       assertEquals(Status.Code.UNAVAILABLE, sre.getStatus().getCode());
326       if (sre.getCause() instanceof SSLHandshakeException) {
327         assertThat(sre).hasCauseThat().isInstanceOf(SSLHandshakeException.class);
328         assertThat(sre).hasCauseThat().hasMessageThat().contains("SSLV3_ALERT_HANDSHAKE_FAILURE");
329       } else {
330         // Client cert verification is after handshake in TLSv1.3
331         assertThat(sre).hasCauseThat().hasCauseThat().isInstanceOf(SSLException.class);
332         assertThat(sre).hasCauseThat().hasMessageThat().contains("CERTIFICATE_REQUIRED");
333       }
334     }
335   }
336 
337   @Test
channelExceptionDuringNegotiatonPropagatesToStatus()338   public void channelExceptionDuringNegotiatonPropagatesToStatus() throws Exception {
339     negotiator = ProtocolNegotiators.serverPlaintext();
340     startServer();
341 
342     NoopProtocolNegotiator negotiator = new NoopProtocolNegotiator();
343     NettyClientTransport transport = newTransport(negotiator);
344     callMeMaybe(transport.start(clientTransportListener));
345     final Status failureStatus = Status.UNAVAILABLE.withDescription("oh noes!");
346     transport.channel().pipeline().fireExceptionCaught(failureStatus.asRuntimeException());
347 
348     Rpc rpc = new Rpc(transport).halfClose();
349     try {
350       rpc.waitForClose();
351       fail("expected exception");
352     } catch (ExecutionException ex) {
353       assertSame(failureStatus, ((StatusException) ex.getCause()).getStatus());
354     }
355   }
356 
357   @Test
handlerExceptionDuringNegotiatonPropagatesToStatus()358   public void handlerExceptionDuringNegotiatonPropagatesToStatus() throws Exception {
359     negotiator = ProtocolNegotiators.serverPlaintext();
360     startServer();
361 
362     final NoopProtocolNegotiator negotiator = new NoopProtocolNegotiator();
363     final NettyClientTransport transport = newTransport(negotiator);
364     callMeMaybe(transport.start(clientTransportListener));
365     final Status failureStatus = Status.UNAVAILABLE.withDescription("oh noes!");
366     transport.channel().eventLoop().execute(new Runnable() {
367       @Override
368       public void run() {
369         try {
370           negotiator.handler.exceptionCaught(
371               transport.channel().pipeline().context(negotiator.handler),
372               failureStatus.asRuntimeException());
373         } catch (Exception ex) {
374           throw new RuntimeException(ex);
375         }
376       }
377     });
378 
379     Rpc rpc = new Rpc(transport).halfClose();
380     try {
381       rpc.waitForClose();
382       fail("expected exception");
383     } catch (ExecutionException ex) {
384       Status actual = ((StatusException) ex.getCause()).getStatus();
385       assertSame(failureStatus.getCode(), actual.getCode());
386       assertThat(actual.getDescription()).contains(failureStatus.getDescription());
387     }
388   }
389 
390   @Test
bufferedStreamsShouldBeClosedWhenConnectionTerminates()391   public void bufferedStreamsShouldBeClosedWhenConnectionTerminates() throws Exception {
392     // Only allow a single stream active at a time.
393     startServer(1, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE);
394 
395     NettyClientTransport transport = newTransport(newNegotiator());
396     callMeMaybe(transport.start(clientTransportListener));
397 
398     // Send a dummy RPC in order to ensure that the updated SETTINGS_MAX_CONCURRENT_STREAMS
399     // has been received by the remote endpoint.
400     new Rpc(transport).halfClose().waitForResponse();
401 
402     // Create 3 streams, but don't half-close. The transport will buffer the second and third.
403     Rpc[] rpcs = new Rpc[] { new Rpc(transport), new Rpc(transport), new Rpc(transport) };
404 
405     // Wait for the response for the stream that was actually created.
406     rpcs[0].waitForResponse();
407 
408     // Now forcibly terminate the connection from the server side.
409     serverListener.transports.get(0).channel().pipeline().firstContext().close();
410 
411     // Now wait for both listeners to be closed.
412     for (int i = 1; i < rpcs.length; i++) {
413       try {
414         rpcs[i].waitForClose();
415         fail("Expected the RPC to fail");
416       } catch (ExecutionException e) {
417         // Expected.
418         Throwable t = getRootCause(e);
419         // Make sure that the Http2ChannelClosedException got replaced with the real cause of
420         // the shutdown.
421         assertFalse(t instanceof StreamBufferingEncoder.Http2ChannelClosedException);
422       }
423     }
424   }
425 
426   public static class CantConstructChannel extends NioSocketChannel {
427     /** Constructor. It doesn't work. Feel free to try. But it doesn't work. */
CantConstructChannel()428     public CantConstructChannel() {
429       // Use an Error because we've seen cases of channels failing to construct due to classloading
430       // problems (like mixing different versions of Netty), and those involve Errors.
431       throw new CantConstructChannelError();
432     }
433   }
434 
435   private static class CantConstructChannelError extends Error {}
436 
437   @Test
failingToConstructChannelShouldFailGracefully()438   public void failingToConstructChannelShouldFailGracefully() throws Exception {
439     address = TestUtils.testServerAddress(new InetSocketAddress(12345));
440     authority = GrpcUtil.authorityFromHostAndPort(address.getHostString(), address.getPort());
441     NettyClientTransport transport = new NettyClientTransport(
442         address, new ReflectiveChannelFactory<>(CantConstructChannel.class),
443         new HashMap<ChannelOption<?>, Object>(), group,
444         newNegotiator(), false, DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE,
445         GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, KEEPALIVE_TIME_NANOS_DISABLED, 1, false, authority,
446         null, tooManyPingsRunnable, new TransportTracer(), Attributes.EMPTY, new SocketPicker(),
447         new FakeChannelLogger(), false, Ticker.systemTicker());
448     transports.add(transport);
449 
450     // Should not throw
451     callMeMaybe(transport.start(clientTransportListener));
452 
453     // And RPCs and PINGs should fail cleanly, reporting the failure
454     Rpc rpc = new Rpc(transport);
455     try {
456       rpc.waitForResponse();
457       fail("Expected exception");
458     } catch (Exception ex) {
459       if (!(getRootCause(ex) instanceof CantConstructChannelError)) {
460         throw new AssertionError("Could not find expected error", ex);
461       }
462     }
463 
464     final SettableFuture<Object> pingResult = SettableFuture.create();
465     FakeClock clock = new FakeClock();
466     ClientTransport.PingCallback pingCallback = new ClientTransport.PingCallback() {
467       @Override
468       public void onSuccess(long roundTripTimeNanos) {
469         pingResult.set(roundTripTimeNanos);
470       }
471 
472       @Override
473       public void onFailure(Throwable cause) {
474         pingResult.setException(cause);
475       }
476     };
477     transport.ping(pingCallback, clock.getScheduledExecutorService());
478     assertFalse(pingResult.isDone());
479     clock.runDueTasks();
480     assertTrue(pingResult.isDone());
481     try {
482       pingResult.get();
483       fail("Expected exception");
484     } catch (Exception ex) {
485       if (!(getRootCause(ex) instanceof CantConstructChannelError)) {
486         throw new AssertionError("Could not find expected error", ex);
487       }
488     }
489   }
490 
491   @Test
channelFactoryShouldSetSocketOptionKeepAlive()492   public void channelFactoryShouldSetSocketOptionKeepAlive() throws Exception {
493     startServer();
494     NettyClientTransport transport = newTransport(newNegotiator(),
495         DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, "testUserAgent", true,
496         TimeUnit.SECONDS.toNanos(10L), TimeUnit.SECONDS.toNanos(1L),
497         new ReflectiveChannelFactory<>(NioSocketChannel.class), group);
498 
499     callMeMaybe(transport.start(clientTransportListener));
500 
501     assertThat(transport.channel().config().getOption(ChannelOption.SO_KEEPALIVE))
502         .isTrue();
503   }
504 
505   @Test
channelFactoryShouldNNotSetSocketOptionKeepAlive()506   public void channelFactoryShouldNNotSetSocketOptionKeepAlive() throws Exception {
507     startServer();
508     NettyClientTransport transport = newTransport(newNegotiator(),
509         DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, "testUserAgent", true,
510         TimeUnit.SECONDS.toNanos(10L), TimeUnit.SECONDS.toNanos(1L),
511         new ReflectiveChannelFactory<>(LocalChannel.class), group);
512 
513     callMeMaybe(transport.start(clientTransportListener));
514 
515     assertThat(transport.channel().config().getOption(ChannelOption.SO_KEEPALIVE))
516         .isNull();
517   }
518 
519   @Test
maxHeaderListSizeShouldBeEnforcedOnClient()520   public void maxHeaderListSizeShouldBeEnforcedOnClient() throws Exception {
521     startServer();
522 
523     NettyClientTransport transport =
524         newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE, 1, null, true);
525     callMeMaybe(transport.start(clientTransportListener));
526 
527     try {
528       // Send a single RPC and wait for the response.
529       new Rpc(transport, new Metadata()).halfClose().waitForResponse();
530       fail("The stream should have been failed due to client received header exceeds header list"
531           + " size limit!");
532     } catch (Exception e) {
533       Throwable rootCause = getRootCause(e);
534       Status status = ((StatusException) rootCause).getStatus();
535       assertEquals(Status.Code.INTERNAL, status.getCode());
536       assertEquals("RST_STREAM closed stream. HTTP/2 error code: PROTOCOL_ERROR",
537           status.getDescription());
538     }
539   }
540 
541   @Test
maxHeaderListSizeShouldBeEnforcedOnServer()542   public void maxHeaderListSizeShouldBeEnforcedOnServer() throws Exception {
543     startServer(100, 1);
544 
545     NettyClientTransport transport = newTransport(newNegotiator());
546     callMeMaybe(transport.start(clientTransportListener));
547 
548     try {
549       // Send a single RPC and wait for the response.
550       new Rpc(transport, new Metadata()).halfClose().waitForResponse();
551       fail("The stream should have been failed due to server received header exceeds header list"
552           + " size limit!");
553     } catch (Exception e) {
554       Status status = Status.fromThrowable(e);
555       assertEquals(status.toString(), Status.Code.INTERNAL, status.getCode());
556     }
557   }
558 
559   @Test
getAttributes_negotiatorHandler()560   public void getAttributes_negotiatorHandler() throws Exception {
561     address = TestUtils.testServerAddress(new InetSocketAddress(12345));
562     authority = GrpcUtil.authorityFromHostAndPort(address.getHostString(), address.getPort());
563 
564     NettyClientTransport transport = newTransport(new NoopProtocolNegotiator());
565     callMeMaybe(transport.start(clientTransportListener));
566 
567     assertNotNull(transport.getAttributes());
568   }
569 
570   @Test
getEagAttributes_negotiatorHandler()571   public void getEagAttributes_negotiatorHandler() throws Exception {
572     address = TestUtils.testServerAddress(new InetSocketAddress(12345));
573     authority = GrpcUtil.authorityFromHostAndPort(address.getHostString(), address.getPort());
574 
575     NoopProtocolNegotiator npn = new NoopProtocolNegotiator();
576     eagAttributes = Attributes.newBuilder()
577         .set(Attributes.Key.create("trash"), "value")
578         .build();
579     NettyClientTransport transport = newTransport(npn);
580     callMeMaybe(transport.start(clientTransportListener));
581 
582     // EAG Attributes are available before the negotiation is complete
583     assertSame(eagAttributes, npn.grpcHandler.getEagAttributes());
584   }
585 
586   @Test
clientStreamGetsAttributes()587   public void clientStreamGetsAttributes() throws Exception {
588     startServer();
589     NettyClientTransport transport = newTransport(newNegotiator());
590     callMeMaybe(transport.start(clientTransportListener));
591     Rpc rpc = new Rpc(transport).halfClose();
592     rpc.waitForResponse();
593 
594     assertNotNull(rpc.stream.getAttributes().get(Grpc.TRANSPORT_ATTR_SSL_SESSION));
595     assertEquals(address, rpc.stream.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
596     Attributes serverTransportAttrs = serverTransportAttributesList.poll(1, TimeUnit.SECONDS);
597     assertNotNull(serverTransportAttrs);
598     SocketAddress clientAddr = serverTransportAttrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR);
599     assertNotNull(clientAddr);
600     assertEquals(clientAddr, rpc.stream.getAttributes().get(Grpc.TRANSPORT_ATTR_LOCAL_ADDR));
601   }
602 
603   @Test
keepAliveEnabled()604   public void keepAliveEnabled() throws Exception {
605     startServer();
606     NettyClientTransport transport = newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE,
607         GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null /* user agent */, true /* keep alive */);
608     callMeMaybe(transport.start(clientTransportListener));
609     Rpc rpc = new Rpc(transport).halfClose();
610     rpc.waitForResponse();
611 
612     assertNotNull(transport.keepAliveManager());
613   }
614 
615   @Test
keepAliveDisabled()616   public void keepAliveDisabled() throws Exception {
617     startServer();
618     NettyClientTransport transport = newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE,
619         GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null /* user agent */, false /* keep alive */);
620     callMeMaybe(transport.start(clientTransportListener));
621     Rpc rpc = new Rpc(transport).halfClose();
622     rpc.waitForResponse();
623 
624     assertNull(transport.keepAliveManager());
625   }
626 
627   @Test
keepAliveEnabled_shouldSetTcpUserTimeout()628   public void keepAliveEnabled_shouldSetTcpUserTimeout() throws Exception {
629     assume().that(Utils.isEpollAvailable()).isTrue();
630 
631     startServer();
632     EventLoopGroup epollGroup = Utils.DEFAULT_WORKER_EVENT_LOOP_GROUP.create();
633     int keepAliveTimeMillis = 12345670;
634     int keepAliveTimeoutMillis = 1234567;
635     try {
636       NettyClientTransport transport = newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE,
637           GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null /* user agent */, true /* keep alive */,
638           TimeUnit.MILLISECONDS.toNanos(keepAliveTimeMillis),
639           TimeUnit.MILLISECONDS.toNanos(keepAliveTimeoutMillis),
640           new ReflectiveChannelFactory<>(Utils.DEFAULT_CLIENT_CHANNEL_TYPE), epollGroup);
641 
642       callMeMaybe(transport.start(clientTransportListener));
643 
644       ChannelOption<Integer> tcpUserTimeoutOption = Utils.maybeGetTcpUserTimeoutOption();
645       assertThat(tcpUserTimeoutOption).isNotNull();
646       // on some linux based system, the integer value may have error (usually +-1)
647       assertThat((double) transport.channel().config().getOption(tcpUserTimeoutOption))
648           .isWithin(5.0).of((double) keepAliveTimeoutMillis);
649     } finally {
650       epollGroup.shutdownGracefully();
651     }
652   }
653 
654   @Test
keepAliveDisabled_shouldNotSetTcpUserTimeout()655   public void keepAliveDisabled_shouldNotSetTcpUserTimeout() throws Exception {
656     assume().that(Utils.isEpollAvailable()).isTrue();
657 
658     startServer();
659     EventLoopGroup epollGroup = Utils.DEFAULT_WORKER_EVENT_LOOP_GROUP.create();
660     int keepAliveTimeMillis = 12345670;
661     try {
662       long keepAliveTimeNanos = TimeUnit.MILLISECONDS.toNanos(keepAliveTimeMillis);
663       NettyClientTransport transport = newTransport(newNegotiator(), DEFAULT_MAX_MESSAGE_SIZE,
664           GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, null /* user agent */, false /* keep alive */,
665           keepAliveTimeNanos, keepAliveTimeNanos,
666           new ReflectiveChannelFactory<>(Utils.DEFAULT_CLIENT_CHANNEL_TYPE), epollGroup);
667 
668       callMeMaybe(transport.start(clientTransportListener));
669 
670       ChannelOption<Integer> tcpUserTimeoutOption = Utils.maybeGetTcpUserTimeoutOption();
671       assertThat(tcpUserTimeoutOption).isNotNull();
672       // default TCP_USER_TIMEOUT=0 (use the system default)
673       assertThat(transport.channel().config().getOption(tcpUserTimeoutOption)).isEqualTo(0);
674     } finally {
675       epollGroup.shutdownGracefully();
676     }
677   }
678 
679   /**
680    * Verifies that we can successfully build a server and client negotiator with tls and the
681    * executor passing in, and without resource leak after closing the negotiator.
682    */
683   @Test
tlsNegotiationServerExecutorShouldSucceed()684   public void tlsNegotiationServerExecutorShouldSucceed() throws Exception {
685     // initialize the client and server Executor pool
686     TrackingObjectPoolForTest serverExecutorPool = new TrackingObjectPoolForTest();
687     TrackingObjectPoolForTest clientExecutorPool = new TrackingObjectPoolForTest();
688     assertEquals(false, serverExecutorPool.isInUse());
689     assertEquals(false, clientExecutorPool.isInUse());
690 
691     InputStream serverCert = TlsTesting.loadCert("server1.pem");
692     InputStream serverKey = TlsTesting.loadCert("server1.key");
693     SslContext sslContext = GrpcSslContexts.forServer(serverCert, serverKey)
694         .clientAuth(ClientAuth.NONE)
695         .build();
696     negotiator = ProtocolNegotiators.serverTls(sslContext, serverExecutorPool);
697     startServer();
698     // after starting the server, the Executor in the server pool should be used
699     assertEquals(true, serverExecutorPool.isInUse());
700 
701     InputStream caCert = TlsTesting.loadCert("ca.pem");
702     InputStream clientCert = TlsTesting.loadCert("client.pem");
703     InputStream clientKey = TlsTesting.loadCert("client.key");
704     SslContext clientContext = GrpcSslContexts.forClient()
705         .trustManager(caCert)
706         .keyManager(clientCert, clientKey)
707         .build();
708     ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool);
709     // after starting the client, the Executor in the client pool should be used
710     assertEquals(true, clientExecutorPool.isInUse());
711     final NettyClientTransport transport = newTransport(negotiator);
712     callMeMaybe(transport.start(clientTransportListener));
713     Rpc rpc = new Rpc(transport).halfClose();
714     rpc.waitForResponse();
715     // closing the negotiators should return the executors back to pool, and release the resource
716     negotiator.close();
717     assertEquals(false, clientExecutorPool.isInUse());
718     this.negotiator.close();
719     assertEquals(false, serverExecutorPool.isInUse());
720   }
721 
getRootCause(Throwable t)722   private Throwable getRootCause(Throwable t) {
723     if (t.getCause() == null) {
724       return t;
725     }
726     return getRootCause(t.getCause());
727   }
728 
newNegotiator()729   private ProtocolNegotiator newNegotiator() throws IOException {
730     InputStream caCert = TlsTesting.loadCert("ca.pem");
731     SslContext clientContext = GrpcSslContexts.forClient().trustManager(caCert).build();
732     return ProtocolNegotiators.tls(clientContext);
733   }
734 
newTransport(ProtocolNegotiator negotiator)735   private NettyClientTransport newTransport(ProtocolNegotiator negotiator) {
736     return newTransport(negotiator, DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE,
737         null /* user agent */, true /* keep alive */);
738   }
739 
newTransport(ProtocolNegotiator negotiator, int maxMsgSize, int maxHeaderListSize, String userAgent, boolean enableKeepAlive)740   private NettyClientTransport newTransport(ProtocolNegotiator negotiator, int maxMsgSize,
741       int maxHeaderListSize, String userAgent, boolean enableKeepAlive) {
742     return newTransport(negotiator, maxMsgSize, maxHeaderListSize, userAgent, enableKeepAlive,
743         TimeUnit.SECONDS.toNanos(10L), TimeUnit.SECONDS.toNanos(1L),
744         new ReflectiveChannelFactory<>(NioSocketChannel.class), group);
745   }
746 
newTransport(ProtocolNegotiator negotiator, int maxMsgSize, int maxHeaderListSize, String userAgent, boolean enableKeepAlive, long keepAliveTimeNano, long keepAliveTimeoutNano, ChannelFactory<? extends Channel> channelFactory, EventLoopGroup group)747   private NettyClientTransport newTransport(ProtocolNegotiator negotiator, int maxMsgSize,
748       int maxHeaderListSize, String userAgent, boolean enableKeepAlive, long keepAliveTimeNano,
749       long keepAliveTimeoutNano, ChannelFactory<? extends Channel> channelFactory,
750       EventLoopGroup group) {
751     if (!enableKeepAlive) {
752       keepAliveTimeNano = KEEPALIVE_TIME_NANOS_DISABLED;
753     }
754     NettyClientTransport transport = new NettyClientTransport(
755         address, channelFactory, new HashMap<ChannelOption<?>, Object>(), group,
756         negotiator, false, DEFAULT_WINDOW_SIZE, maxMsgSize, maxHeaderListSize,
757         keepAliveTimeNano, keepAliveTimeoutNano,
758         false, authority, userAgent, tooManyPingsRunnable,
759         new TransportTracer(), eagAttributes, new SocketPicker(), new FakeChannelLogger(), false,
760         Ticker.systemTicker());
761     transports.add(transport);
762     return transport;
763   }
764 
startServer()765   private void startServer() throws IOException {
766     startServer(100, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE);
767   }
768 
startServer(int maxStreamsPerConnection, int maxHeaderListSize)769   private void startServer(int maxStreamsPerConnection, int maxHeaderListSize) throws IOException {
770     server = new NettyServer(
771         TestUtils.testServerAddresses(new InetSocketAddress(0)),
772         new ReflectiveChannelFactory<>(NioServerSocketChannel.class),
773         new HashMap<ChannelOption<?>, Object>(),
774         new HashMap<ChannelOption<?>, Object>(),
775         new FixedObjectPool<>(group), new FixedObjectPool<>(group), false, negotiator,
776         Collections.<ServerStreamTracer.Factory>emptyList(),
777         TransportTracer.getDefaultFactory(),
778         maxStreamsPerConnection,
779         false,
780         DEFAULT_WINDOW_SIZE, DEFAULT_MAX_MESSAGE_SIZE, maxHeaderListSize,
781         DEFAULT_SERVER_KEEPALIVE_TIME_NANOS, DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS,
782         MAX_CONNECTION_IDLE_NANOS_DISABLED,
783         MAX_CONNECTION_AGE_NANOS_DISABLED, MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE, true, 0,
784         Attributes.EMPTY,
785         channelz);
786     server.start(serverListener);
787     address = TestUtils.testServerAddress((InetSocketAddress) server.getListenSocketAddress());
788     authority = GrpcUtil.authorityFromHostAndPort(address.getHostString(), address.getPort());
789   }
790 
callMeMaybe(Runnable r)791   private void callMeMaybe(Runnable r) {
792     if (r != null) {
793       r.run();
794     }
795   }
796 
createSslContext()797   private static SslContext createSslContext() {
798     try {
799       InputStream serverCert = TlsTesting.loadCert("server1.pem");
800       InputStream key = TlsTesting.loadCert("server1.key");
801       return GrpcSslContexts.forServer(serverCert, key).build();
802     } catch (IOException ex) {
803       throw new RuntimeException(ex);
804     }
805   }
806 
807   private static class Rpc {
808     static final String MESSAGE = "hello";
809     static final MethodDescriptor<String, String> METHOD =
810         MethodDescriptor.<String, String>newBuilder()
811             .setType(MethodDescriptor.MethodType.UNARY)
812             .setFullMethodName("testService/test")
813             .setRequestMarshaller(StringMarshaller.INSTANCE)
814             .setResponseMarshaller(StringMarshaller.INSTANCE)
815             .build();
816 
817     final ClientStream stream;
818     final TestClientStreamListener listener = new TestClientStreamListener();
819 
Rpc(NettyClientTransport transport)820     Rpc(NettyClientTransport transport) {
821       this(transport, new Metadata());
822     }
823 
Rpc(NettyClientTransport transport, Metadata headers)824     Rpc(NettyClientTransport transport, Metadata headers) {
825       stream = transport.newStream(
826           METHOD, headers, CallOptions.DEFAULT,
827           new ClientStreamTracer[]{ new ClientStreamTracer() {} });
828       stream.start(listener);
829       stream.request(1);
830       stream.writeMessage(new ByteArrayInputStream(MESSAGE.getBytes(UTF_8)));
831       stream.flush();
832     }
833 
halfClose()834     Rpc halfClose() {
835       stream.halfClose();
836       return this;
837     }
838 
waitForResponse()839     void waitForResponse() throws InterruptedException, ExecutionException, TimeoutException {
840       listener.responseFuture.get(10, TimeUnit.SECONDS);
841     }
842 
waitForClose()843     void waitForClose() throws InterruptedException, ExecutionException, TimeoutException {
844       listener.closedFuture.get(10, TimeUnit.SECONDS);
845     }
846   }
847 
848   private static final class TestClientStreamListener implements ClientStreamListener {
849     final SettableFuture<Void> closedFuture = SettableFuture.create();
850     final SettableFuture<Void> responseFuture = SettableFuture.create();
851 
852     @Override
headersRead(Metadata headers)853     public void headersRead(Metadata headers) {
854     }
855 
856     @Override
closed(Status status, RpcProgress rpcProgress, Metadata trailers)857     public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) {
858       if (status.isOk()) {
859         closedFuture.set(null);
860       } else {
861         StatusException e = status.asException();
862         closedFuture.setException(e);
863         responseFuture.setException(e);
864       }
865     }
866 
867     @Override
messagesAvailable(MessageProducer producer)868     public void messagesAvailable(MessageProducer producer) {
869       if (producer.next() != null) {
870         responseFuture.set(null);
871       }
872     }
873 
874     @Override
onReady()875     public void onReady() {
876     }
877   }
878 
879   private static final class EchoServerStreamListener implements ServerStreamListener {
880     final ServerStream stream;
881     final Metadata headers;
882 
EchoServerStreamListener(ServerStream stream, Metadata headers)883     EchoServerStreamListener(ServerStream stream, Metadata headers) {
884       this.stream = stream;
885       this.headers = headers;
886     }
887 
888     @Override
messagesAvailable(MessageProducer producer)889     public void messagesAvailable(MessageProducer producer) {
890       InputStream message;
891       while ((message = producer.next()) != null) {
892         // Just echo back the message.
893         stream.writeMessage(message);
894         stream.flush();
895       }
896     }
897 
898     @Override
onReady()899     public void onReady() {
900     }
901 
902     @Override
halfClosed()903     public void halfClosed() {
904       // Just close when the client closes.
905       stream.close(Status.OK, new Metadata());
906     }
907 
908     @Override
closed(Status status)909     public void closed(Status status) {
910     }
911   }
912 
913   private final class EchoServerListener implements ServerListener {
914     final List<NettyServerTransport> transports = new ArrayList<>();
915     final List<EchoServerStreamListener> streamListeners =
916             Collections.synchronizedList(new ArrayList<EchoServerStreamListener>());
917 
918     @Override
transportCreated(final ServerTransport transport)919     public ServerTransportListener transportCreated(final ServerTransport transport) {
920       transports.add((NettyServerTransport) transport);
921       return new ServerTransportListener() {
922         @Override
923         public void streamCreated(ServerStream stream, String method, Metadata headers) {
924           EchoServerStreamListener listener = new EchoServerStreamListener(stream, headers);
925           stream.setListener(listener);
926           stream.writeHeaders(new Metadata());
927           stream.request(1);
928           streamListeners.add(listener);
929         }
930 
931         @Override
932         public Attributes transportReady(Attributes transportAttrs) {
933           serverTransportAttributesList.add(transportAttrs);
934           return transportAttrs;
935         }
936 
937         @Override
938         public void transportTerminated() {}
939       };
940     }
941 
942     @Override
serverShutdown()943     public void serverShutdown() {
944     }
945   }
946 
947   private static final class StringMarshaller implements Marshaller<String> {
948     static final StringMarshaller INSTANCE = new StringMarshaller();
949 
950     @Override
951     public InputStream stream(String value) {
952       return new ByteArrayInputStream(value.getBytes(UTF_8));
953     }
954 
955     @Override
956     public String parse(InputStream stream) {
957       try {
958         return new String(ByteStreams.toByteArray(stream), UTF_8);
959       } catch (IOException ex) {
960         throw new RuntimeException(ex);
961       }
962     }
963   }
964 
965   private static class NoopHandler extends ChannelDuplexHandler {
966 
967     private final GrpcHttp2ConnectionHandler grpcHandler;
968 
969     public NoopHandler(GrpcHttp2ConnectionHandler grpcHandler) {
970       this.grpcHandler = grpcHandler;
971     }
972 
973     @Override
974     public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
975       ctx.pipeline().addBefore(ctx.name(), null, grpcHandler);
976     }
977 
978     public void fail(ChannelHandlerContext ctx, Throwable cause) {
979       ctx.fireExceptionCaught(cause);
980     }
981   }
982 
983   private static class NoopProtocolNegotiator implements ProtocolNegotiator {
984     GrpcHttp2ConnectionHandler grpcHandler;
985     NoopHandler handler;
986 
987     @Override
988     public ChannelHandler newHandler(final GrpcHttp2ConnectionHandler grpcHandler) {
989       this.grpcHandler = grpcHandler;
990       return handler = new NoopHandler(grpcHandler);
991     }
992 
993     @Override
994     public AsciiString scheme() {
995       return Utils.HTTP;
996     }
997 
998     @Override
999     public void close() {}
1000   }
1001 
1002   private static final class SocketPicker extends LocalSocketPicker {
1003 
1004     @Nullable
1005     @Override
1006     public SocketAddress createSocketAddress(SocketAddress remoteAddress, Attributes attrs) {
1007       return null;
1008     }
1009   }
1010 
1011   private static final class FakeChannelLogger extends ChannelLogger {
1012 
1013     @Override
1014     public void log(ChannelLogLevel level, String message) {}
1015 
1016     @Override
1017     public void log(ChannelLogLevel level, String messageFormat, Object... args) {}
1018   }
1019 }
1020