• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright 2018 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package io.grpc.alts.internal;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static java.nio.charset.StandardCharsets.UTF_8;
21 import static org.junit.Assert.assertEquals;
22 import static org.junit.Assert.assertFalse;
23 import static org.junit.Assert.assertNotNull;
24 import static org.junit.Assert.assertTrue;
25 
26 import io.grpc.Attributes;
27 import io.grpc.CallCredentials;
28 import io.grpc.Grpc;
29 import io.grpc.InternalChannelz;
30 import io.grpc.SecurityLevel;
31 import io.grpc.alts.internal.Handshaker.HandshakerResult;
32 import io.grpc.alts.internal.TsiFrameProtector.Consumer;
33 import io.grpc.alts.internal.TsiPeer.Property;
34 import io.grpc.netty.GrpcHttp2ConnectionHandler;
35 import io.netty.buffer.ByteBuf;
36 import io.netty.buffer.ByteBufAllocator;
37 import io.netty.buffer.CompositeByteBuf;
38 import io.netty.buffer.Unpooled;
39 import io.netty.channel.ChannelDuplexHandler;
40 import io.netty.channel.ChannelFuture;
41 import io.netty.channel.ChannelFutureListener;
42 import io.netty.channel.ChannelHandler;
43 import io.netty.channel.ChannelHandlerContext;
44 import io.netty.channel.ChannelPromise;
45 import io.netty.channel.embedded.EmbeddedChannel;
46 import io.netty.handler.codec.http2.DefaultHttp2Connection;
47 import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder;
48 import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder;
49 import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
50 import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
51 import io.netty.handler.codec.http2.Http2Connection;
52 import io.netty.handler.codec.http2.Http2ConnectionDecoder;
53 import io.netty.handler.codec.http2.Http2ConnectionEncoder;
54 import io.netty.handler.codec.http2.Http2FrameReader;
55 import io.netty.handler.codec.http2.Http2FrameWriter;
56 import io.netty.handler.codec.http2.Http2Settings;
57 import io.netty.util.ReferenceCountUtil;
58 import io.netty.util.ReferenceCounted;
59 import java.nio.ByteBuffer;
60 import java.security.GeneralSecurityException;
61 import java.util.ArrayList;
62 import java.util.Arrays;
63 import java.util.Collections;
64 import java.util.List;
65 import java.util.concurrent.Future;
66 import java.util.concurrent.LinkedBlockingQueue;
67 import java.util.concurrent.TimeUnit;
68 import java.util.concurrent.atomic.AtomicInteger;
69 import java.util.concurrent.atomic.AtomicReference;
70 import org.junit.After;
71 import org.junit.Before;
72 import org.junit.Test;
73 import org.junit.runner.RunWith;
74 import org.junit.runners.JUnit4;
75 
76 /** Tests for {@link AltsProtocolNegotiator}. */
77 @RunWith(JUnit4.class)
78 public class AltsProtocolNegotiatorTest {
79   private final CapturingGrpcHttp2ConnectionHandler grpcHandler = capturingGrpcHandler();
80 
81   private final List<ReferenceCounted> references = new ArrayList<>();
82   private final LinkedBlockingQueue<InterceptingProtector> protectors = new LinkedBlockingQueue<>();
83 
84   private EmbeddedChannel channel;
85   private Throwable caughtException;
86 
87   private volatile TsiHandshakeHandler.TsiHandshakeCompletionEvent tsiEvent;
88   private ChannelHandler handler;
89 
90   private TsiPeer mockedTsiPeer = new TsiPeer(Collections.<Property<?>>emptyList());
91   private AltsAuthContext mockedAltsContext =
92       new AltsAuthContext(
93           HandshakerResult.newBuilder()
94               .setPeerRpcVersions(RpcProtocolVersionsUtil.getRpcProtocolVersions())
95               .build());
96   private final TsiHandshaker mockHandshaker =
97       new DelegatingTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerServer()) {
98         @Override
99         public TsiPeer extractPeer() throws GeneralSecurityException {
100           return mockedTsiPeer;
101         }
102 
103         @Override
104         public Object extractPeerObject() throws GeneralSecurityException {
105           return mockedAltsContext;
106         }
107       };
108   private final NettyTsiHandshaker serverHandshaker = new NettyTsiHandshaker(mockHandshaker);
109 
110   @Before
setup()111   public void setup() throws Exception {
112     ChannelHandler userEventHandler =
113         new ChannelDuplexHandler() {
114           @Override
115           public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
116             if (evt instanceof TsiHandshakeHandler.TsiHandshakeCompletionEvent) {
117               tsiEvent = (TsiHandshakeHandler.TsiHandshakeCompletionEvent) evt;
118             } else {
119               super.userEventTriggered(ctx, evt);
120             }
121           }
122         };
123 
124     ChannelHandler uncaughtExceptionHandler =
125         new ChannelDuplexHandler() {
126           @Override
127           public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
128             caughtException = cause;
129             super.exceptionCaught(ctx, cause);
130           }
131         };
132 
133     TsiHandshakerFactory handshakerFactory =
134         new DelegatingTsiHandshakerFactory(FakeTsiHandshaker.clientHandshakerFactory()) {
135           @Override
136           public TsiHandshaker newHandshaker() {
137             return new DelegatingTsiHandshaker(super.newHandshaker()) {
138               @Override
139               public TsiPeer extractPeer() throws GeneralSecurityException {
140                 return mockedTsiPeer;
141               }
142 
143               @Override
144               public Object extractPeerObject() throws GeneralSecurityException {
145                 return mockedAltsContext;
146               }
147             };
148           }
149         };
150     handler = AltsProtocolNegotiator.create(handshakerFactory).newHandler(grpcHandler);
151     channel = new EmbeddedChannel(uncaughtExceptionHandler, handler, userEventHandler);
152   }
153 
154   @After
teardown()155   public void teardown() throws Exception {
156     if (channel != null) {
157       @SuppressWarnings("unused") // go/futurereturn-lsc
158       Future<?> possiblyIgnoredError = channel.close();
159     }
160 
161     for (ReferenceCounted reference : references) {
162       ReferenceCountUtil.safeRelease(reference);
163     }
164   }
165 
166   @Test
handshakeShouldBeSuccessful()167   public void handshakeShouldBeSuccessful() throws Exception {
168     doHandshake();
169   }
170 
171   @Test
172   @SuppressWarnings("unchecked") // List cast
protectShouldRoundtrip()173   public void protectShouldRoundtrip() throws Exception {
174     // Write the message 1 character at a time. The message should be buffered
175     // and not interfere with the handshake.
176     final AtomicInteger writeCount = new AtomicInteger();
177     String message = "hello";
178     for (int ix = 0; ix < message.length(); ++ix) {
179       ByteBuf in = Unpooled.copiedBuffer(message, ix, 1, UTF_8);
180       @SuppressWarnings("unused") // go/futurereturn-lsc
181       Future<?> possiblyIgnoredError =
182           channel
183               .write(in)
184               .addListener(
185                   new ChannelFutureListener() {
186                     @Override
187                     public void operationComplete(ChannelFuture future) throws Exception {
188                       if (future.isSuccess()) {
189                         writeCount.incrementAndGet();
190                       }
191                     }
192                   });
193     }
194     channel.flush();
195 
196     // Now do the handshake. The buffered message will automatically be protected
197     // and sent.
198     doHandshake();
199 
200     // Capture the protected data written to the wire.
201     assertEquals(1, channel.outboundMessages().size());
202     ByteBuf protectedData = channel.<ByteBuf>readOutbound();
203     assertEquals(message.length(), writeCount.get());
204 
205     // Read the protected message at the server and verify it matches the original message.
206     TsiFrameProtector serverProtector = serverHandshaker.createFrameProtector(channel.alloc());
207     List<ByteBuf> unprotected = new ArrayList<>();
208     serverProtector.unprotect(protectedData, (List<Object>) (List<?>) unprotected, channel.alloc());
209     // We try our best to remove the HTTP2 handler as soon as possible, but just by constructing it
210     // a settings frame is written (and an HTTP2 preface).  This is hard coded into Netty, so we
211     // have to remove it here.  See {@code Http2ConnectionHandler.PrefaceDecode.sendPreface}.
212     int settingsFrameLength = 9;
213 
214     CompositeByteBuf unprotectedAll =
215         new CompositeByteBuf(channel.alloc(), false, unprotected.size() + 1, unprotected);
216     ByteBuf unprotectedData = unprotectedAll.slice(settingsFrameLength, message.length());
217     assertEquals(message, unprotectedData.toString(UTF_8));
218 
219     // Protect the same message at the server.
220     final AtomicReference<ByteBuf> newlyProtectedData = new AtomicReference<>();
221     serverProtector.protectFlush(
222         Collections.singletonList(unprotectedData),
223         new Consumer<ByteBuf>() {
224           @Override
225           public void accept(ByteBuf buf) {
226             newlyProtectedData.set(buf);
227           }
228         },
229         channel.alloc());
230 
231     // Read the protected message at the client and verify that it matches the original message.
232     channel.writeInbound(newlyProtectedData.get());
233     assertEquals(1, channel.inboundMessages().size());
234     assertEquals(message, channel.<ByteBuf>readInbound().toString(UTF_8));
235   }
236 
237   @Test
unprotectLargeIncomingFrame()238   public void unprotectLargeIncomingFrame() throws Exception {
239 
240     // We use a server frameprotector with twice the standard frame size.
241     int serverFrameSize = 4096 * 2;
242     // This should fit into one frame.
243     byte[] unprotectedBytes = new byte[serverFrameSize - 500];
244     Arrays.fill(unprotectedBytes, (byte) 7);
245     ByteBuf unprotectedData = Unpooled.wrappedBuffer(unprotectedBytes);
246     unprotectedData.writerIndex(unprotectedBytes.length);
247 
248     // Perform handshake.
249     doHandshake();
250 
251     // Protect the message on the server.
252     TsiFrameProtector serverProtector =
253         serverHandshaker.createFrameProtector(serverFrameSize, channel.alloc());
254     serverProtector.protectFlush(
255         Collections.singletonList(unprotectedData),
256         new Consumer<ByteBuf>() {
257           @Override
258           public void accept(ByteBuf buf) {
259             channel.writeInbound(buf);
260           }
261         },
262         channel.alloc());
263     channel.flushInbound();
264 
265     // Read the protected message at the client and verify that it matches the original message.
266     assertEquals(1, channel.inboundMessages().size());
267 
268     ByteBuf receivedData1 = channel.<ByteBuf>readInbound();
269     int receivedLen1 = receivedData1.readableBytes();
270     byte[] receivedBytes = new byte[receivedLen1];
271     receivedData1.readBytes(receivedBytes, 0, receivedLen1);
272 
273     assertThat(unprotectedBytes.length).isEqualTo(receivedBytes.length);
274     assertThat(unprotectedBytes).isEqualTo(receivedBytes);
275   }
276 
277   @Test
flushShouldFailAllPromises()278   public void flushShouldFailAllPromises() throws Exception {
279     doHandshake();
280 
281     channel
282         .pipeline()
283         .addFirst(
284             new ChannelDuplexHandler() {
285               @Override
286               public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
287                   throws Exception {
288                 throw new Exception("Fake exception");
289               }
290             });
291 
292     // Write the message 1 character at a time.
293     String message = "hello";
294     final AtomicInteger failures = new AtomicInteger();
295     for (int ix = 0; ix < message.length(); ++ix) {
296       ByteBuf in = Unpooled.copiedBuffer(message, ix, 1, UTF_8);
297       @SuppressWarnings("unused") // go/futurereturn-lsc
298       Future<?> possiblyIgnoredError =
299           channel
300               .write(in)
301               .addListener(
302                   new ChannelFutureListener() {
303                     @Override
304                     public void operationComplete(ChannelFuture future) throws Exception {
305                       if (!future.isSuccess()) {
306                         failures.incrementAndGet();
307                       }
308                     }
309                   });
310     }
311     channel.flush();
312 
313     // Verify that the promises fail.
314     assertEquals(message.length(), failures.get());
315   }
316 
317   @Test
doNotFlushEmptyBuffer()318   public void doNotFlushEmptyBuffer() throws Exception {
319     doHandshake();
320     assertEquals(1, protectors.size());
321     InterceptingProtector protector = protectors.poll();
322 
323     String message = "hello";
324     ByteBuf in = Unpooled.copiedBuffer(message, UTF_8);
325 
326     assertEquals(0, protector.flushes.get());
327     Future<?> done = channel.write(in);
328     channel.flush();
329     done.get(5, TimeUnit.SECONDS);
330     assertEquals(1, protector.flushes.get());
331 
332     done = channel.write(Unpooled.EMPTY_BUFFER);
333     channel.flush();
334     done.get(5, TimeUnit.SECONDS);
335     assertEquals(1, protector.flushes.get());
336   }
337 
338   @Test
peerPropagated()339   public void peerPropagated() throws Exception {
340     doHandshake();
341 
342     assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.getTsiPeerAttributeKey()))
343         .isEqualTo(mockedTsiPeer);
344     assertThat(grpcHandler.attrs.get(AltsProtocolNegotiator.getAltsAuthContextAttributeKey()))
345         .isEqualTo(mockedAltsContext);
346     assertThat(grpcHandler.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR).toString())
347         .isEqualTo("embedded");
348     assertThat(grpcHandler.attrs.get(CallCredentials.ATTR_SECURITY_LEVEL))
349         .isEqualTo(SecurityLevel.PRIVACY_AND_INTEGRITY);
350   }
351 
doHandshake()352   private void doHandshake() throws Exception {
353     // Capture the client frame and add to the server.
354     assertEquals(1, channel.outboundMessages().size());
355     ByteBuf clientFrame = channel.<ByteBuf>readOutbound();
356     assertTrue(serverHandshaker.processBytesFromPeer(clientFrame));
357 
358     // Get the server response handshake frames.
359     ByteBuf serverFrame = channel.alloc().buffer();
360     serverHandshaker.getBytesToSendToPeer(serverFrame);
361     channel.writeInbound(serverFrame);
362 
363     // Capture the next client frame and add to the server.
364     assertEquals(1, channel.outboundMessages().size());
365     clientFrame = channel.<ByteBuf>readOutbound();
366     assertTrue(serverHandshaker.processBytesFromPeer(clientFrame));
367 
368     // Get the server response handshake frames.
369     serverFrame = channel.alloc().buffer();
370     serverHandshaker.getBytesToSendToPeer(serverFrame);
371     channel.writeInbound(serverFrame);
372 
373     // Ensure that both sides have confirmed that the handshake has completed.
374     assertFalse(serverHandshaker.isInProgress());
375 
376     if (caughtException != null) {
377       throw new RuntimeException(caughtException);
378     }
379     assertNotNull(tsiEvent);
380   }
381 
capturingGrpcHandler()382   private CapturingGrpcHttp2ConnectionHandler capturingGrpcHandler() {
383     // Netty Boilerplate.  We don't really need any of this, but there is a tight coupling
384     // between a Http2ConnectionHandler and its dependencies.
385     Http2Connection connection = new DefaultHttp2Connection(true);
386     Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter();
387     Http2FrameReader frameReader = new DefaultHttp2FrameReader(false);
388     DefaultHttp2ConnectionEncoder encoder =
389         new DefaultHttp2ConnectionEncoder(connection, frameWriter);
390     DefaultHttp2ConnectionDecoder decoder =
391         new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader);
392 
393     return new CapturingGrpcHttp2ConnectionHandler(decoder, encoder, new Http2Settings());
394   }
395 
396   private final class CapturingGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {
397     private Attributes attrs;
398 
CapturingGrpcHttp2ConnectionHandler( Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder, Http2Settings initialSettings)399     private CapturingGrpcHttp2ConnectionHandler(
400         Http2ConnectionDecoder decoder,
401         Http2ConnectionEncoder encoder,
402         Http2Settings initialSettings) {
403       super(null, decoder, encoder, initialSettings);
404     }
405 
406     @Override
handleProtocolNegotiationCompleted( Attributes attrs, InternalChannelz.Security securityInfo)407     public void handleProtocolNegotiationCompleted(
408         Attributes attrs, InternalChannelz.Security securityInfo) {
409       // If we are added to the pipeline, we need to remove ourselves.  The HTTP2 handler
410       channel.pipeline().remove(this);
411       this.attrs = attrs;
412     }
413   }
414 
415   private static class DelegatingTsiHandshakerFactory implements TsiHandshakerFactory {
416 
417     private TsiHandshakerFactory delegate;
418 
DelegatingTsiHandshakerFactory(TsiHandshakerFactory delegate)419     DelegatingTsiHandshakerFactory(TsiHandshakerFactory delegate) {
420       this.delegate = delegate;
421     }
422 
423     @Override
newHandshaker()424     public TsiHandshaker newHandshaker() {
425       return delegate.newHandshaker();
426     }
427   }
428 
429   private class DelegatingTsiHandshaker implements TsiHandshaker {
430 
431     private final TsiHandshaker delegate;
432 
DelegatingTsiHandshaker(TsiHandshaker delegate)433     DelegatingTsiHandshaker(TsiHandshaker delegate) {
434       this.delegate = delegate;
435     }
436 
437     @Override
getBytesToSendToPeer(ByteBuffer bytes)438     public void getBytesToSendToPeer(ByteBuffer bytes) throws GeneralSecurityException {
439       delegate.getBytesToSendToPeer(bytes);
440     }
441 
442     @Override
processBytesFromPeer(ByteBuffer bytes)443     public boolean processBytesFromPeer(ByteBuffer bytes) throws GeneralSecurityException {
444       return delegate.processBytesFromPeer(bytes);
445     }
446 
447     @Override
isInProgress()448     public boolean isInProgress() {
449       return delegate.isInProgress();
450     }
451 
452     @Override
extractPeer()453     public TsiPeer extractPeer() throws GeneralSecurityException {
454       return delegate.extractPeer();
455     }
456 
457     @Override
extractPeerObject()458     public Object extractPeerObject() throws GeneralSecurityException {
459       return delegate.extractPeerObject();
460     }
461 
462     @Override
createFrameProtector(ByteBufAllocator alloc)463     public TsiFrameProtector createFrameProtector(ByteBufAllocator alloc) {
464       InterceptingProtector protector =
465           new InterceptingProtector(delegate.createFrameProtector(alloc));
466       protectors.add(protector);
467       return protector;
468     }
469 
470     @Override
createFrameProtector(int maxFrameSize, ByteBufAllocator alloc)471     public TsiFrameProtector createFrameProtector(int maxFrameSize, ByteBufAllocator alloc) {
472       InterceptingProtector protector =
473           new InterceptingProtector(delegate.createFrameProtector(maxFrameSize, alloc));
474       protectors.add(protector);
475       return protector;
476     }
477   }
478 
479   private static class InterceptingProtector implements TsiFrameProtector {
480     private final TsiFrameProtector delegate;
481     final AtomicInteger flushes = new AtomicInteger();
482 
InterceptingProtector(TsiFrameProtector delegate)483     InterceptingProtector(TsiFrameProtector delegate) {
484       this.delegate = delegate;
485     }
486 
487     @Override
protectFlush( List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)488     public void protectFlush(
489         List<ByteBuf> unprotectedBufs, Consumer<ByteBuf> ctxWrite, ByteBufAllocator alloc)
490         throws GeneralSecurityException {
491       flushes.incrementAndGet();
492       delegate.protectFlush(unprotectedBufs, ctxWrite, alloc);
493     }
494 
495     @Override
unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)496     public void unprotect(ByteBuf in, List<Object> out, ByteBufAllocator alloc)
497         throws GeneralSecurityException {
498       delegate.unprotect(in, out, alloc);
499     }
500 
501     @Override
destroy()502     public void destroy() {
503       delegate.destroy();
504     }
505   }
506 }
507