• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2014 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.netty;
18 
19 import static com.google.common.base.Charsets.UTF_8;
20 import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE;
21 import static org.junit.Assert.assertEquals;
22 import static org.mockito.AdditionalAnswers.delegatesTo;
23 import static org.mockito.Matchers.any;
24 import static org.mockito.Matchers.anyLong;
25 import static org.mockito.Mockito.atLeastOnce;
26 import static org.mockito.Mockito.doAnswer;
27 import static org.mockito.Mockito.mock;
28 import static org.mockito.Mockito.verify;
29 import static org.mockito.Mockito.when;
30 
31 import com.google.errorprone.annotations.CanIgnoreReturnValue;
32 import io.grpc.InternalChannelz.TransportStats;
33 import io.grpc.internal.FakeClock;
34 import io.grpc.internal.MessageFramer;
35 import io.grpc.internal.StatsTraceContext;
36 import io.grpc.internal.TransportTracer;
37 import io.grpc.internal.WritableBuffer;
38 import io.netty.buffer.ByteBuf;
39 import io.netty.buffer.ByteBufAllocator;
40 import io.netty.buffer.ByteBufUtil;
41 import io.netty.buffer.CompositeByteBuf;
42 import io.netty.buffer.Unpooled;
43 import io.netty.buffer.UnpooledByteBufAllocator;
44 import io.netty.channel.ChannelFuture;
45 import io.netty.channel.ChannelHandler;
46 import io.netty.channel.ChannelHandlerContext;
47 import io.netty.channel.ChannelPromise;
48 import io.netty.channel.EventLoop;
49 import io.netty.channel.embedded.EmbeddedChannel;
50 import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
51 import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
52 import io.netty.handler.codec.http2.Http2CodecUtil;
53 import io.netty.handler.codec.http2.Http2Connection;
54 import io.netty.handler.codec.http2.Http2ConnectionHandler;
55 import io.netty.handler.codec.http2.Http2Exception;
56 import io.netty.handler.codec.http2.Http2FrameReader;
57 import io.netty.handler.codec.http2.Http2FrameWriter;
58 import io.netty.handler.codec.http2.Http2Headers;
59 import io.netty.handler.codec.http2.Http2HeadersDecoder;
60 import io.netty.handler.codec.http2.Http2LocalFlowController;
61 import io.netty.handler.codec.http2.Http2Settings;
62 import io.netty.handler.codec.http2.Http2Stream;
63 import io.netty.util.concurrent.DefaultPromise;
64 import io.netty.util.concurrent.Promise;
65 import io.netty.util.concurrent.ScheduledFuture;
66 import java.io.ByteArrayInputStream;
67 import java.util.concurrent.Delayed;
68 import java.util.concurrent.TimeUnit;
69 import org.junit.Test;
70 import org.junit.runner.RunWith;
71 import org.junit.runners.JUnit4;
72 import org.mockito.ArgumentCaptor;
73 import org.mockito.invocation.InvocationOnMock;
74 import org.mockito.stubbing.Answer;
75 import org.mockito.verification.VerificationMode;
76 
77 /**
78  * Base class for Netty handler unit tests.
79  */
80 @RunWith(JUnit4.class)
81 public abstract class NettyHandlerTestBase<T extends Http2ConnectionHandler> {
82 
83   private ByteBuf content;
84 
85   private EmbeddedChannel channel;
86 
87   private ChannelHandlerContext ctx;
88 
89   private Http2FrameWriter frameWriter;
90 
91   private Http2FrameReader frameReader;
92 
93   private T handler;
94 
95   private WriteQueue writeQueue;
96 
97   /**
98    * Does additional setup jobs. Call it manually when necessary.
99    */
manualSetUp()100   protected void manualSetUp() throws Exception {}
101 
102   protected final TransportTracer transportTracer = new TransportTracer();
103   protected int flowControlWindow = DEFAULT_WINDOW_SIZE;
104 
105   private final FakeClock fakeClock = new FakeClock();
106 
fakeClock()107   FakeClock fakeClock() {
108     return fakeClock;
109   }
110 
111   /**
112    * Must be called by subclasses to initialize the handler and channel.
113    */
initChannel(Http2HeadersDecoder headersDecoder)114   protected final void initChannel(Http2HeadersDecoder headersDecoder) throws Exception {
115     content = Unpooled.copiedBuffer("hello world", UTF_8);
116     frameWriter = mock(Http2FrameWriter.class, delegatesTo(new DefaultHttp2FrameWriter()));
117     frameReader = new DefaultHttp2FrameReader(headersDecoder);
118 
119     channel = new FakeClockSupportedChanel();
120     handler = newHandler();
121     channel.pipeline().addLast(handler);
122     ctx = channel.pipeline().context(handler);
123 
124     writeQueue = initWriteQueue();
125   }
126 
127   private final class FakeClockSupportedChanel extends EmbeddedChannel {
128     EventLoop eventLoop;
129 
FakeClockSupportedChanel(ChannelHandler... handlers)130     FakeClockSupportedChanel(ChannelHandler... handlers) {
131       super(handlers);
132     }
133 
134     @Override
eventLoop()135     public EventLoop eventLoop() {
136       if (eventLoop == null) {
137         createEventLoop();
138       }
139       return eventLoop;
140     }
141 
createEventLoop()142     void createEventLoop() {
143       EventLoop realEventLoop = super.eventLoop();
144       if (realEventLoop == null) {
145         return;
146       }
147       eventLoop = mock(EventLoop.class, delegatesTo(realEventLoop));
148       doAnswer(
149           new Answer<ScheduledFuture<Void>>() {
150             @Override
151             public ScheduledFuture<Void> answer(InvocationOnMock invocation) throws Throwable {
152               Runnable command = (Runnable) invocation.getArguments()[0];
153               Long delay = (Long) invocation.getArguments()[1];
154               TimeUnit timeUnit = (TimeUnit) invocation.getArguments()[2];
155               return new FakeClockScheduledNettyFuture(eventLoop, command, delay, timeUnit);
156             }
157           }).when(eventLoop).schedule(any(Runnable.class), anyLong(), any(TimeUnit.class));
158     }
159   }
160 
161   private final class FakeClockScheduledNettyFuture extends DefaultPromise<Void>
162       implements ScheduledFuture<Void> {
163     final java.util.concurrent.ScheduledFuture<?> future;
164 
FakeClockScheduledNettyFuture( EventLoop eventLoop, final Runnable command, long delay, TimeUnit timeUnit)165     FakeClockScheduledNettyFuture(
166         EventLoop eventLoop, final Runnable command, long delay, TimeUnit timeUnit) {
167       super(eventLoop);
168       Runnable wrap = new Runnable() {
169         @Override
170         public void run() {
171           try {
172             command.run();
173           } catch (Throwable t) {
174             setFailure(t);
175             return;
176           }
177           if (!isDone()) {
178             Promise<Void> unused = setSuccess(null);
179           }
180           // else: The command itself, such as a shutdown task, might have cancelled all the
181           // scheduled tasks already.
182         }
183       };
184       future = fakeClock.getScheduledExecutorService().schedule(wrap, delay, timeUnit);
185     }
186 
187     @Override
cancel(boolean mayInterruptIfRunning)188     public boolean cancel(boolean mayInterruptIfRunning) {
189       if (future.cancel(mayInterruptIfRunning)) {
190         return super.cancel(mayInterruptIfRunning);
191       }
192       return false;
193     }
194 
195     @Override
getDelay(TimeUnit unit)196     public long getDelay(TimeUnit unit) {
197       return Math.max(future.getDelay(unit), 1L); // never return zero or negative delay.
198     }
199 
200     @Override
compareTo(Delayed o)201     public int compareTo(Delayed o) {
202       return future.compareTo(o);
203     }
204   }
205 
handler()206   protected final T handler() {
207     return handler;
208   }
209 
channel()210   protected final EmbeddedChannel channel() {
211     return channel;
212   }
213 
ctx()214   protected final ChannelHandlerContext ctx() {
215     return ctx;
216   }
217 
frameWriter()218   protected final Http2FrameWriter frameWriter() {
219     return frameWriter;
220   }
221 
frameReader()222   protected final Http2FrameReader frameReader() {
223     return frameReader;
224   }
225 
content()226   protected final ByteBuf content() {
227     return content;
228   }
229 
contentAsArray()230   protected final byte[] contentAsArray() {
231     return ByteBufUtil.getBytes(content());
232   }
233 
verifyWrite()234   protected final Http2FrameWriter verifyWrite() {
235     return verify(frameWriter);
236   }
237 
verifyWrite(VerificationMode verificationMode)238   protected final Http2FrameWriter verifyWrite(VerificationMode verificationMode) {
239     return verify(frameWriter, verificationMode);
240   }
241 
channelRead(Object obj)242   protected final void channelRead(Object obj) throws Exception {
243     channel.writeInbound(obj);
244   }
245 
grpcDataFrame(int streamId, boolean endStream, byte[] content)246   protected ByteBuf grpcDataFrame(int streamId, boolean endStream, byte[] content) {
247     final ByteBuf compressionFrame = Unpooled.buffer(content.length);
248     MessageFramer framer = new MessageFramer(
249         new MessageFramer.Sink() {
250           @Override
251           public void deliverFrame(
252               WritableBuffer frame, boolean endOfStream, boolean flush, int numMessages) {
253             if (frame != null) {
254               ByteBuf bytebuf = ((NettyWritableBuffer) frame).bytebuf();
255               compressionFrame.writeBytes(bytebuf);
256             }
257           }
258         },
259         new NettyWritableBufferAllocator(ByteBufAllocator.DEFAULT),
260         StatsTraceContext.NOOP);
261     framer.writePayload(new ByteArrayInputStream(content));
262     framer.flush();
263     ChannelHandlerContext ctx = newMockContext();
264     new DefaultHttp2FrameWriter().writeData(ctx, streamId, compressionFrame, 0, endStream,
265         newPromise());
266     return captureWrite(ctx);
267   }
268 
dataFrame(int streamId, boolean endStream, ByteBuf content)269   protected final ByteBuf dataFrame(int streamId, boolean endStream, ByteBuf content) {
270     // Need to retain the content since the frameWriter releases it.
271     content.retain();
272 
273     ChannelHandlerContext ctx = newMockContext();
274     new DefaultHttp2FrameWriter().writeData(ctx, streamId, content, 0, endStream, newPromise());
275     return captureWrite(ctx);
276   }
277 
pingFrame(boolean ack, long payload)278   protected final ByteBuf pingFrame(boolean ack, long payload) {
279     ChannelHandlerContext ctx = newMockContext();
280     new DefaultHttp2FrameWriter().writePing(ctx, ack, payload, newPromise());
281     return captureWrite(ctx);
282   }
283 
headersFrame(int streamId, Http2Headers headers)284   protected final ByteBuf headersFrame(int streamId, Http2Headers headers) {
285     ChannelHandlerContext ctx = newMockContext();
286     new DefaultHttp2FrameWriter().writeHeaders(ctx, streamId, headers, 0, false, newPromise());
287     return captureWrite(ctx);
288   }
289 
goAwayFrame(int lastStreamId)290   protected final ByteBuf goAwayFrame(int lastStreamId) {
291     return goAwayFrame(lastStreamId, 0, Unpooled.EMPTY_BUFFER);
292   }
293 
goAwayFrame(int lastStreamId, int errorCode, ByteBuf data)294   protected final ByteBuf goAwayFrame(int lastStreamId, int errorCode, ByteBuf data) {
295     ChannelHandlerContext ctx = newMockContext();
296     new DefaultHttp2FrameWriter().writeGoAway(ctx, lastStreamId, errorCode, data, newPromise());
297     return captureWrite(ctx);
298   }
299 
rstStreamFrame(int streamId, int errorCode)300   protected final ByteBuf rstStreamFrame(int streamId, int errorCode) {
301     ChannelHandlerContext ctx = newMockContext();
302     new DefaultHttp2FrameWriter().writeRstStream(ctx, streamId, errorCode, newPromise());
303     return captureWrite(ctx);
304   }
305 
serializeSettings(Http2Settings settings)306   protected final ByteBuf serializeSettings(Http2Settings settings) {
307     ChannelHandlerContext ctx = newMockContext();
308     new DefaultHttp2FrameWriter().writeSettings(ctx, settings, newPromise());
309     return captureWrite(ctx);
310   }
311 
windowUpdate(int streamId, int delta)312   protected final ByteBuf windowUpdate(int streamId, int delta) {
313     ChannelHandlerContext ctx = newMockContext();
314     new DefaultHttp2FrameWriter().writeWindowUpdate(ctx, 0, delta, newPromise());
315     return captureWrite(ctx);
316   }
317 
newPromise()318   protected final ChannelPromise newPromise() {
319     return channel.newPromise();
320   }
321 
connection()322   protected final Http2Connection connection() {
323     return handler().connection();
324   }
325 
326   @CanIgnoreReturnValue
enqueue(WriteQueue.QueuedCommand command)327   protected final ChannelFuture enqueue(WriteQueue.QueuedCommand command) {
328     ChannelFuture future = writeQueue.enqueue(command, true);
329     channel.runPendingTasks();
330     return future;
331   }
332 
newMockContext()333   protected final ChannelHandlerContext newMockContext() {
334     ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
335     when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
336     EventLoop eventLoop = mock(EventLoop.class);
337     when(ctx.executor()).thenReturn(eventLoop);
338     when(ctx.channel()).thenReturn(channel);
339     return ctx;
340   }
341 
captureWrite(ChannelHandlerContext ctx)342   protected final ByteBuf captureWrite(ChannelHandlerContext ctx) {
343     ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
344     verify(ctx, atLeastOnce()).write(captor.capture(), any(ChannelPromise.class));
345     CompositeByteBuf composite = Unpooled.compositeBuffer();
346     for (ByteBuf buf : captor.getAllValues()) {
347       composite.addComponent(buf);
348       composite.writerIndex(composite.writerIndex() + buf.readableBytes());
349     }
350     return composite;
351   }
352 
newHandler()353   protected abstract T newHandler() throws Http2Exception;
354 
initWriteQueue()355   protected abstract WriteQueue initWriteQueue();
356 
makeStream()357   protected abstract void makeStream() throws Exception;
358 
359   @Test
dataPingSentOnHeaderRecieved()360   public void dataPingSentOnHeaderRecieved() throws Exception {
361     manualSetUp();
362     makeStream();
363     AbstractNettyHandler handler = (AbstractNettyHandler) handler();
364     handler.setAutoTuneFlowControl(true);
365 
366     channelRead(dataFrame(3, false, content()));
367 
368     assertEquals(1, handler.flowControlPing().getPingCount());
369   }
370 
371   @Test
dataPingAckIsRecognized()372   public void dataPingAckIsRecognized() throws Exception {
373     manualSetUp();
374     makeStream();
375     AbstractNettyHandler handler = (AbstractNettyHandler) handler();
376     handler.setAutoTuneFlowControl(true);
377 
378     channelRead(dataFrame(3, false, content()));
379     long pingData = handler.flowControlPing().payload();
380     channelRead(pingFrame(true, pingData));
381 
382     assertEquals(1, handler.flowControlPing().getPingCount());
383     assertEquals(1, handler.flowControlPing().getPingReturn());
384   }
385 
386   @Test
dataSizeSincePingAccumulates()387   public void dataSizeSincePingAccumulates() throws Exception {
388     manualSetUp();
389     makeStream();
390     AbstractNettyHandler handler = (AbstractNettyHandler) handler();
391     handler.setAutoTuneFlowControl(true);
392     long frameData = 123456;
393     ByteBuf buff = ctx().alloc().buffer(16);
394     buff.writeLong(frameData);
395     int length = buff.readableBytes();
396 
397     channelRead(dataFrame(3, false, buff.copy()));
398     channelRead(dataFrame(3, false, buff.copy()));
399     channelRead(dataFrame(3, false, buff.copy()));
400 
401     assertEquals(length * 3, handler.flowControlPing().getDataSincePing());
402   }
403 
404   @Test
windowUpdateMatchesTarget()405   public void windowUpdateMatchesTarget() throws Exception {
406     manualSetUp();
407     Http2Stream connectionStream = connection().connectionStream();
408     Http2LocalFlowController localFlowController = connection().local().flowController();
409     makeStream();
410     AbstractNettyHandler handler = (AbstractNettyHandler) handler();
411     handler.setAutoTuneFlowControl(true);
412 
413     ByteBuf data = ctx().alloc().buffer(1024);
414     while (data.isWritable()) {
415       data.writeLong(1111);
416     }
417     int length = data.readableBytes();
418     ByteBuf frame = dataFrame(3, false, data.copy());
419     channelRead(frame);
420     int accumulator = length;
421     // 40 is arbitrary, any number large enough to trigger a window update would work
422     for (int i = 0; i < 40; i++) {
423       channelRead(dataFrame(3, false, data.copy()));
424       accumulator += length;
425     }
426     long pingData = handler.flowControlPing().payload();
427     channelRead(pingFrame(true, pingData));
428 
429     assertEquals(accumulator, handler.flowControlPing().getDataSincePing());
430     assertEquals(2 * accumulator, localFlowController.initialWindowSize(connectionStream));
431   }
432 
433   @Test
windowShouldNotExceedMaxWindowSize()434   public void windowShouldNotExceedMaxWindowSize() throws Exception {
435     manualSetUp();
436     makeStream();
437     AbstractNettyHandler handler = (AbstractNettyHandler) handler();
438     handler.setAutoTuneFlowControl(true);
439     Http2Stream connectionStream = connection().connectionStream();
440     Http2LocalFlowController localFlowController = connection().local().flowController();
441     int maxWindow = handler.flowControlPing().maxWindow();
442 
443     handler.flowControlPing().setDataSizeSincePing(maxWindow);
444     long payload = handler.flowControlPing().payload();
445     channelRead(pingFrame(true, payload));
446 
447     assertEquals(maxWindow, localFlowController.initialWindowSize(connectionStream));
448   }
449 
450   @Test
transportTracer_windowSizeDefault()451   public void transportTracer_windowSizeDefault() throws Exception {
452     manualSetUp();
453     TransportStats transportStats = transportTracer.getStats();
454     assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, transportStats.remoteFlowControlWindow);
455     assertEquals(flowControlWindow, transportStats.localFlowControlWindow);
456   }
457 
458   @Test
transportTracer_windowSize()459   public void transportTracer_windowSize() throws Exception {
460     flowControlWindow = 1024 * 1024;
461     manualSetUp();
462     TransportStats transportStats = transportTracer.getStats();
463     assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, transportStats.remoteFlowControlWindow);
464     assertEquals(flowControlWindow, transportStats.localFlowControlWindow);
465   }
466 
467   @Test
transportTracer_windowUpdate_remote()468   public void transportTracer_windowUpdate_remote() throws Exception {
469     manualSetUp();
470     TransportStats before = transportTracer.getStats();
471     assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, before.remoteFlowControlWindow);
472     assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, before.localFlowControlWindow);
473 
474     ByteBuf serializedSettings = windowUpdate(0, 1000);
475     channelRead(serializedSettings);
476     TransportStats after = transportTracer.getStats();
477     assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE + 1000,
478         after.remoteFlowControlWindow);
479     assertEquals(flowControlWindow, after.localFlowControlWindow);
480   }
481 
482   @Test
transportTracer_windowUpdate_local()483   public void transportTracer_windowUpdate_local() throws Exception {
484     manualSetUp();
485     TransportStats before = transportTracer.getStats();
486     assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, before.remoteFlowControlWindow);
487     assertEquals(flowControlWindow, before.localFlowControlWindow);
488 
489     // If the window size is below a certain threshold, netty will wait to apply the update.
490     // Use a large increment to be sure that it exceeds the threshold.
491     connection().local().flowController().incrementWindowSize(
492         connection().connectionStream(), 8 * Http2CodecUtil.DEFAULT_WINDOW_SIZE);
493 
494     TransportStats after = transportTracer.getStats();
495     assertEquals(Http2CodecUtil.DEFAULT_WINDOW_SIZE, after.remoteFlowControlWindow);
496     assertEquals(flowControlWindow + 8 * Http2CodecUtil.DEFAULT_WINDOW_SIZE,
497         connection().local().flowController().windowSize(connection().connectionStream()));
498   }
499 }
500