1 /* 2 * Copyright 2014 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.netty; 18 19 import static com.google.common.base.Charsets.UTF_8; 20 import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE; 21 import static org.junit.Assert.assertEquals; 22 import static org.mockito.AdditionalAnswers.delegatesTo; 23 import static org.mockito.Matchers.any; 24 import static org.mockito.Matchers.anyLong; 25 import static org.mockito.Mockito.atLeastOnce; 26 import static org.mockito.Mockito.doAnswer; 27 import static org.mockito.Mockito.mock; 28 import static org.mockito.Mockito.verify; 29 import static org.mockito.Mockito.when; 30 31 import com.google.errorprone.annotations.CanIgnoreReturnValue; 32 import io.grpc.InternalChannelz.TransportStats; 33 import io.grpc.internal.FakeClock; 34 import io.grpc.internal.MessageFramer; 35 import io.grpc.internal.StatsTraceContext; 36 import io.grpc.internal.TransportTracer; 37 import io.grpc.internal.WritableBuffer; 38 import io.netty.buffer.ByteBuf; 39 import io.netty.buffer.ByteBufAllocator; 40 import io.netty.buffer.ByteBufUtil; 41 import io.netty.buffer.CompositeByteBuf; 42 import io.netty.buffer.Unpooled; 43 import io.netty.buffer.UnpooledByteBufAllocator; 44 import io.netty.channel.ChannelFuture; 45 import io.netty.channel.ChannelHandler; 46 import io.netty.channel.ChannelHandlerContext; 47 import io.netty.channel.ChannelPromise; 48 import io.netty.channel.EventLoop; 49 import io.netty.channel.embedded.EmbeddedChannel; 50 import io.netty.handler.codec.http2.DefaultHttp2FrameReader; 51 import io.netty.handler.codec.http2.DefaultHttp2FrameWriter; 52 import io.netty.handler.codec.http2.Http2CodecUtil; 53 import io.netty.handler.codec.http2.Http2Connection; 54 import io.netty.handler.codec.http2.Http2ConnectionHandler; 55 import io.netty.handler.codec.http2.Http2Exception; 56 import io.netty.handler.codec.http2.Http2FrameReader; 57 import io.netty.handler.codec.http2.Http2FrameWriter; 58 import io.netty.handler.codec.http2.Http2Headers; 59 import io.netty.handler.codec.http2.Http2HeadersDecoder; 60 import io.netty.handler.codec.http2.Http2LocalFlowController; 61 import io.netty.handler.codec.http2.Http2Settings; 62 import io.netty.handler.codec.http2.Http2Stream; 63 import io.netty.util.concurrent.DefaultPromise; 64 import io.netty.util.concurrent.Promise; 65 import io.netty.util.concurrent.ScheduledFuture; 66 import java.io.ByteArrayInputStream; 67 import java.util.concurrent.Delayed; 68 import java.util.concurrent.TimeUnit; 69 import org.junit.Test; 70 import org.junit.runner.RunWith; 71 import org.junit.runners.JUnit4; 72 import org.mockito.ArgumentCaptor; 73 import org.mockito.invocation.InvocationOnMock; 74 import org.mockito.stubbing.Answer; 75 import org.mockito.verification.VerificationMode; 76 77 /** 78 * Base class for Netty handler unit tests. 79 */ 80 @RunWith(JUnit4.class) 81 public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> { 82 83 private ByteBuf content; 84 85 private EmbeddedChannel channel; 86 87 private ChannelHandlerContext ctx; 88 89 private Http2FrameWriter frameWriter; 90 91 private Http2FrameReader frameReader; 92 93 private T handler; 94 95 private WriteQueue writeQueue; 96 97 /** 98 * Does additional setup jobs. Call it manually when necessary. 99 */ manualSetUp()100 protected void manualSetUp() throws Exception {} 101 102 protected final TransportTracer transportTracer = new TransportTracer(); 103 protected int flowControlWindow = DEFAULT_WINDOW_SIZE; 104 105 private final FakeClock fakeClock = new FakeClock(); 106 fakeClock()107 FakeClock fakeClock() { 108 return fakeClock; 109 } 110 111 /** 112 * Must be called by subclasses to initialize the handler and channel. 113 */ initChannel(Http2HeadersDecoder headersDecoder)114 protected final void initChannel(Http2HeadersDecoder headersDecoder) throws Exception { 115 content = Unpooled.copiedBuffer("hello world", UTF_8); 116 frameWriter = mock(Http2FrameWriter.class, delegatesTo(new DefaultHttp2FrameWriter())); 117 frameReader = new DefaultHttp2FrameReader(headersDecoder); 118 119 channel = new FakeClockSupportedChanel(); 120 handler = newHandler(); 121 channel.pipeline().addLast(handler); 122 ctx = channel.pipeline().context(handler); 123 124 writeQueue = initWriteQueue(); 125 } 126 127 private final class FakeClockSupportedChanel extends EmbeddedChannel { 128 EventLoop eventLoop; 129 FakeClockSupportedChanel(ChannelHandler... handlers)130 FakeClockSupportedChanel(ChannelHandler... handlers) { 131 super(handlers); 132 } 133 134 @Override eventLoop()135 public EventLoop eventLoop() { 136 if (eventLoop == null) { 137 createEventLoop(); 138 } 139 return eventLoop; 140 } 141 createEventLoop()142 void createEventLoop() { 143 EventLoop realEventLoop = super.eventLoop(); 144 if (realEventLoop == null) { 145 return; 146 } 147 eventLoop = mock(EventLoop.class, delegatesTo(realEventLoop)); 148 doAnswer( 149 new Answer<ScheduledFuture<Void>>() { 150 @Override 151 public ScheduledFuture<Void> answer(InvocationOnMock invocation) throws Throwable { 152 Runnable command = (Runnable) invocation.getArguments()[0]; 153 Long delay = (Long) invocation.getArguments()[1]; 154 TimeUnit timeUnit = (TimeUnit) invocation.getArguments()[2]; 155 return new FakeClockScheduledNettyFuture(eventLoop, command, delay, timeUnit); 156 } 157 }).when(eventLoop).schedule(any(Runnable.class), anyLong(), any(TimeUnit.class)); 158 } 159 } 160 161 private final class FakeClockScheduledNettyFuture extends DefaultPromise<Void> 162 implements ScheduledFuture<Void> { 163 final java.util.concurrent.ScheduledFuture<?> future; 164 FakeClockScheduledNettyFuture( EventLoop eventLoop, final Runnable command, long delay, TimeUnit timeUnit)165 FakeClockScheduledNettyFuture( 166 EventLoop eventLoop, final Runnable command, long delay, TimeUnit timeUnit) { 167 super(eventLoop); 168 Runnable wrap = new Runnable() { 169 @Override 170 public void run() { 171 try { 172 command.run(); 173 } catch (Throwable t) { 174 setFailure(t); 175 return; 176 } 177 if (!isDone()) { 178 Promise<Void> unused = setSuccess(null); 179 } 180 // else: The command itself, such as a shutdown task, might have cancelled all the 181 // scheduled tasks already. 182 } 183 }; 184 future = fakeClock.getScheduledExecutorService().schedule(wrap, delay, timeUnit); 185 } 186 187 @Override cancel(boolean mayInterruptIfRunning)188 public boolean cancel(boolean mayInterruptIfRunning) { 189 if (future.cancel(mayInterruptIfRunning)) { 190 return super.cancel(mayInterruptIfRunning); 191 } 192 return false; 193 } 194 195 @Override getDelay(TimeUnit unit)196 public long getDelay(TimeUnit unit) { 197 return Math.max(future.getDelay(unit), 1L); // never return zero or negative delay. 198 } 199 200 @Override compareTo(Delayed o)201 public int compareTo(Delayed o) { 202 return future.compareTo(o); 203 } 204 } 205 handler()206 protected final T handler() { 207 return handler; 208 } 209 channel()210 protected final EmbeddedChannel channel() { 211 return channel; 212 } 213 ctx()214 protected final ChannelHandlerContext ctx() { 215 return ctx; 216 } 217 frameWriter()218 protected final Http2FrameWriter frameWriter() { 219 return frameWriter; 220 } 221 frameReader()222 protected final Http2FrameReader frameReader() { 223 return frameReader; 224 } 225 content()226 protected final ByteBuf content() { 227 return content; 228 } 229 contentAsArray()230 protected final byte[] contentAsArray() { 231 return ByteBufUtil.getBytes(content()); 232 } 233 verifyWrite()234 protected final Http2FrameWriter verifyWrite() { 235 return verify(frameWriter); 236 } 237 verifyWrite(VerificationMode verificationMode)238 protected final Http2FrameWriter verifyWrite(VerificationMode verificationMode) { 239 return verify(frameWriter, verificationMode); 240 } 241 channelRead(Object obj)242 protected final void channelRead(Object obj) throws Exception { 243 channel.writeInbound(obj); 244 } 245 grpcDataFrame(int streamId, boolean endStream, byte[] content)246 protected ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) { 247 final ByteBuf compressionFrame = Unpooled.buffer(content.length); 248 MessageFramer framer = new MessageFramer( 249 new MessageFramer.Sink() { 250 @Override 251 public void deliverFrame( 252 WritableBuffer frame, boolean endOfStream, boolean flush, int numMessages) { 253 if (frame != null) { 254 ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf(); 255 compressionFrame.writeBytes(bytebuf); 256 } 257 } 258 }, 259 new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT), 260 StatsTraceContext.NOOP); 261 framer.writePayload(new ByteArrayInputStream(content)); 262 framer.flush(); 263 ChannelHandlerContext ctx = newMockContext(); 264 new DefaultHttp2FrameWriter().writeData(ctx, streamId, compressionFrame, 0, endStream, 265 newPromise()); 266 return captureWrite(ctx); 267 } 268 dataFrame(int streamId, boolean endStream, ByteBuf content)269 protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) { 270 // Need to retain the content since the frameWriter releases it. 271 content.retain(); 272 273 ChannelHandlerContext ctx = newMockContext(); 274 new DefaultHttp2FrameWriter().writeData(ctx, streamId, content, 0, endStream, newPromise()); 275 return captureWrite(ctx); 276 } 277 pingFrame(boolean ack, long payload)278 protected final ByteBuf pingFrame(boolean ack, long payload) { 279 ChannelHandlerContext ctx = newMockContext(); 280 new DefaultHttp2FrameWriter().writePing(ctx, ack, payload, newPromise()); 281 return captureWrite(ctx); 282 } 283 headersFrame(int streamId, Http2Headers headers)284 protected final ByteBuf headersFrame(int streamId, Http2Headers headers) { 285 ChannelHandlerContext ctx = newMockContext(); 286 new DefaultHttp2FrameWriter().writeHeaders(ctx, streamId, headers, 0, false, newPromise()); 287 return captureWrite(ctx); 288 } 289 goAwayFrame(int lastStreamId)290 protected final ByteBuf goAwayFrame(int lastStreamId) { 291 return goAwayFrame(lastStreamId, 0, Unpooled.EMPTY_BUFFER); 292 } 293 goAwayFrame(int lastStreamId, int errorCode, ByteBuf data)294 protected final ByteBuf goAwayFrame(int lastStreamId, int errorCode, ByteBuf data) { 295 ChannelHandlerContext ctx = newMockContext(); 296 new DefaultHttp2FrameWriter().writeGoAway(ctx, lastStreamId, errorCode, data, newPromise()); 297 return captureWrite(ctx); 298 } 299 rstStreamFrame(int streamId, int errorCode)300 protected final ByteBuf rstStreamFrame(int streamId, int errorCode) { 301 ChannelHandlerContext ctx = newMockContext(); 302 new DefaultHttp2FrameWriter().writeRstStream(ctx, streamId, errorCode, newPromise()); 303 return captureWrite(ctx); 304 } 305 serializeSettings(Http2Settings settings)306 protected final ByteBuf serializeSettings(Http2Settings settings) { 307 ChannelHandlerContext ctx = newMockContext(); 308 new DefaultHttp2FrameWriter().writeSettings(ctx, settings, newPromise()); 309 return captureWrite(ctx); 310 } 311 windowUpdate(int streamId, int delta)312 protected final ByteBuf windowUpdate(int streamId, int delta) { 313 ChannelHandlerContext ctx = newMockContext(); 314 new DefaultHttp2FrameWriter().writeWindowUpdate(ctx, 0, delta, newPromise()); 315 return captureWrite(ctx); 316 } 317 newPromise()318 protected final ChannelPromise newPromise() { 319 return channel.newPromise(); 320 } 321 connection()322 protected final Http2Connection connection() { 323 return handler().connection(); 324 } 325 326 @CanIgnoreReturnValue enqueue(WriteQueue.QueuedCommand command)327 protected final ChannelFuture enqueue(WriteQueue.QueuedCommand command) { 328 ChannelFuture future = writeQueue.enqueue(command, true); 329 channel.runPendingTasks(); 330 return future; 331 } 332 newMockContext()333 protected final ChannelHandlerContext newMockContext() { 334 ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); 335 when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); 336 EventLoop eventLoop = mock(EventLoop.class); 337 when(ctx.executor()).thenReturn(eventLoop); 338 when(ctx.channel()).thenReturn(channel); 339 return ctx; 340 } 341 captureWrite(ChannelHandlerContext ctx)342 protected final ByteBuf captureWrite(ChannelHandlerContext ctx) { 343 ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class); 344 verify(ctx, atLeastOnce()).write(captor.capture(), any(ChannelPromise.class)); 345 CompositeByteBuf composite = Unpooled.compositeBuffer(); 346 for (ByteBuf buf : captor.getAllValues()) { 347 composite.addComponent(buf); 348 composite.writerIndex(composite.writerIndex() + buf.readableBytes()); 349 } 350 return composite; 351 } 352 newHandler()353 protected abstract T newHandler() throws Http2Exception; 354 initWriteQueue()355 protected abstract WriteQueue initWriteQueue(); 356 makeStream()357 protected abstract void makeStream() throws Exception; 358 359 @Test dataPingSentOnHeaderRecieved()360 public void dataPingSentOnHeaderRecieved() throws Exception { 361 manualSetUp(); 362 makeStream(); 363 AbstractNettyHandler handler = (AbstractNettyHandler) handler(); 364 handler.setAutoTuneFlowControl(true); 365 366 channelRead(dataFrame(3, false, content())); 367 368 assertEquals(1, handler.flowControlPing().getPingCount()); 369 } 370 371 @Test dataPingAckIsRecognized()372 public void dataPingAckIsRecognized() throws Exception { 373 manualSetUp(); 374 makeStream(); 375 AbstractNettyHandler handler = (AbstractNettyHandler) handler(); 376 handler.setAutoTuneFlowControl(true); 377 378 channelRead(dataFrame(3, false, content())); 379 long pingData = handler.flowControlPing().payload(); 380 channelRead(pingFrame(true, pingData)); 381 382 assertEquals(1, handler.flowControlPing().getPingCount()); 383 assertEquals(1, handler.flowControlPing().getPingReturn()); 384 } 385 386 @Test dataSizeSincePingAccumulates()387 public void dataSizeSincePingAccumulates() throws Exception { 388 manualSetUp(); 389 makeStream(); 390 AbstractNettyHandler handler = (AbstractNettyHandler) handler(); 391 handler.setAutoTuneFlowControl(true); 392 long frameData = 123456; 393 ByteBuf buff = ctx().alloc().buffer(16); 394 buff.writeLong(frameData); 395 int length = buff.readableBytes(); 396 397 channelRead(dataFrame(3, false, buff.copy())); 398 channelRead(dataFrame(3, false, buff.copy())); 399 channelRead(dataFrame(3, false, buff.copy())); 400 401 assertEquals(length * 3, handler.flowControlPing().getDataSincePing()); 402 } 403 404 @Test windowUpdateMatchesTarget()405 public void windowUpdateMatchesTarget() throws Exception { 406 manualSetUp(); 407 Http2Stream connectionStream = connection().connectionStream(); 408 Http2LocalFlowController localFlowController = connection().local().flowController(); 409 makeStream(); 410 AbstractNettyHandler handler = (AbstractNettyHandler) handler(); 411 handler.setAutoTuneFlowControl(true); 412 413 ByteBuf data = ctx().alloc().buffer(1024); 414 while (data.isWritable()) { 415 data.writeLong(1111); 416 } 417 int length = data.readableBytes(); 418 ByteBuf frame = dataFrame(3, false, data.copy()); 419 channelRead(frame); 420 int accumulator = length; 421 // 40 is arbitrary, any number large enough to trigger a window update would work 422 for (int i = 0; i < 40; i++) { 423 channelRead(dataFrame(3, false, data.copy())); 424 accumulator += length; 425 } 426 long pingData = handler.flowControlPing().payload(); 427 channelRead(pingFrame(true, pingData)); 428 429 assertEquals(accumulator, handler.flowControlPing().getDataSincePing()); 430 assertEquals(2 * accumulator, localFlowController.initialWindowSize(connectionStream)); 431 } 432 433 @Test windowShouldNotExceedMaxWindowSize()434 public void windowShouldNotExceedMaxWindowSize() throws Exception { 435 manualSetUp(); 436 makeStream(); 437 AbstractNettyHandler handler = (AbstractNettyHandler) handler(); 438 handler.setAutoTuneFlowControl(true); 439 Http2Stream connectionStream = connection().connectionStream(); 440 Http2LocalFlowController localFlowController = connection().local().flowController(); 441 int maxWindow = handler.flowControlPing().maxWindow(); 442 443 handler.flowControlPing().setDataSizeSincePing(maxWindow); 444 long payload = handler.flowControlPing().payload(); 445 channelRead(pingFrame(true, payload)); 446 447 assertEquals(maxWindow, localFlowController.initialWindowSize(connectionStream)); 448 } 449 450 @Test transportTracer_windowSizeDefault()451 public void transportTracer_windowSizeDefault() throws Exception { 452 manualSetUp(); 453 TransportStats transportStats = transportTracer.getStats(); 454 assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, transportStats.remoteFlowControlWindow); 455 assertEquals(flowControlWindow, transportStats.localFlowControlWindow); 456 } 457 458 @Test transportTracer_windowSize()459 public void transportTracer_windowSize() throws Exception { 460 flowControlWindow = 1024 * 1024; 461 manualSetUp(); 462 TransportStats transportStats = transportTracer.getStats(); 463 assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, transportStats.remoteFlowControlWindow); 464 assertEquals(flowControlWindow, transportStats.localFlowControlWindow); 465 } 466 467 @Test transportTracer_windowUpdate_remote()468 public void transportTracer_windowUpdate_remote() throws Exception { 469 manualSetUp(); 470 TransportStats before = transportTracer.getStats(); 471 assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, before.remoteFlowControlWindow); 472 assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, before.localFlowControlWindow); 473 474 ByteBuf serializedSettings = windowUpdate(0, 1000); 475 channelRead(serializedSettings); 476 TransportStats after = transportTracer.getStats(); 477 assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE + 1000, 478 after.remoteFlowControlWindow); 479 assertEquals(flowControlWindow, after.localFlowControlWindow); 480 } 481 482 @Test transportTracer_windowUpdate_local()483 public void transportTracer_windowUpdate_local() throws Exception { 484 manualSetUp(); 485 TransportStats before = transportTracer.getStats(); 486 assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, before.remoteFlowControlWindow); 487 assertEquals(flowControlWindow, before.localFlowControlWindow); 488 489 // If the window size is below a certain threshold, netty will wait to apply the update. 490 // Use a large increment to be sure that it exceeds the threshold. 491 connection().local().flowController().incrementWindowSize( 492 connection().connectionStream(), 8 * Http2CodecUtil.DEFAULT_WINDOW_SIZE); 493 494 TransportStats after = transportTracer.getStats(); 495 assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, after.remoteFlowControlWindow); 496 assertEquals(flowControlWindow + 8 * Http2CodecUtil.DEFAULT_WINDOW_SIZE, 497 connection().local().flowController().windowSize(connection().connectionStream())); 498 } 499 } 500