• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright 2018 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 package io.grpc.alts.internal;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static org.junit.Assert.assertFalse;
21 import static org.junit.Assert.assertNotNull;
22 import static org.junit.Assert.assertTrue;
23 import static org.junit.Assert.fail;
24 
25 import io.netty.buffer.ByteBuf;
26 import io.netty.buffer.ByteBufAllocator;
27 import io.netty.buffer.UnpooledByteBufAllocator;
28 import io.netty.util.ReferenceCounted;
29 import java.lang.reflect.Method;
30 import java.nio.ByteBuffer;
31 import java.security.GeneralSecurityException;
32 import java.util.ArrayList;
33 import java.util.List;
34 import org.junit.After;
35 import org.junit.Test;
36 import org.junit.runner.RunWith;
37 import org.junit.runners.JUnit4;
38 
39 @RunWith(JUnit4.class)
40 public class NettyTsiHandshakerTest {
41   private final UnpooledByteBufAllocator alloc = UnpooledByteBufAllocator.DEFAULT;
42   private final List<ReferenceCounted> references = new ArrayList<>();
43 
44   private final NettyTsiHandshaker clientHandshaker =
45       new NettyTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerClient());
46   private final NettyTsiHandshaker serverHandshaker =
47       new NettyTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerServer());
48 
49   @After
teardown()50   public void teardown() {
51     for (ReferenceCounted reference : references) {
52       reference.release(reference.refCnt());
53     }
54   }
55 
56   @Test
failsOnNullHandshaker()57   public void failsOnNullHandshaker() {
58     try {
59       new NettyTsiHandshaker(null);
60       fail("Exception expected");
61     } catch (NullPointerException ex) {
62       // Do nothing.
63     }
64   }
65 
66   @Test
processPeerHandshakeShouldAcceptPartialFrames()67   public void processPeerHandshakeShouldAcceptPartialFrames() throws GeneralSecurityException {
68     for (int i = 0; i < 1024; i++) {
69       ByteBuf clientData = ref(alloc.buffer(1));
70       clientHandshaker.getBytesToSendToPeer(clientData);
71       if (clientData.isReadable()) {
72         if (serverHandshaker.processBytesFromPeer(clientData)) {
73           // Done.
74           return;
75         }
76       }
77     }
78     fail("Failed to process the handshake frame.");
79   }
80 
81   @Test
handshakeShouldSucceed()82   public void handshakeShouldSucceed() throws GeneralSecurityException {
83     doHandshake();
84   }
85 
86   @Test
isInProgress()87   public void isInProgress() throws GeneralSecurityException {
88     assertTrue(clientHandshaker.isInProgress());
89     assertTrue(serverHandshaker.isInProgress());
90 
91     doHandshake();
92 
93     assertFalse(clientHandshaker.isInProgress());
94     assertFalse(serverHandshaker.isInProgress());
95   }
96 
97   @Test
extractPeer_notNull()98   public void extractPeer_notNull() throws GeneralSecurityException {
99     doHandshake();
100 
101     assertNotNull(serverHandshaker.extractPeer());
102     assertNotNull(clientHandshaker.extractPeer());
103   }
104 
105   @Test
extractPeer_failsBeforeHandshake()106   public void extractPeer_failsBeforeHandshake() throws GeneralSecurityException {
107     try {
108       clientHandshaker.extractPeer();
109       fail("Exception expected");
110     } catch (IllegalStateException ex) {
111       // Do nothing.
112     }
113   }
114 
115   @Test
extractPeerObject_notNull()116   public void extractPeerObject_notNull() throws GeneralSecurityException {
117     doHandshake();
118 
119     assertNotNull(serverHandshaker.extractPeerObject());
120     assertNotNull(clientHandshaker.extractPeerObject());
121   }
122 
123   @Test
extractPeerObject_failsBeforeHandshake()124   public void extractPeerObject_failsBeforeHandshake() throws GeneralSecurityException {
125     try {
126       clientHandshaker.extractPeerObject();
127       fail("Exception expected");
128     } catch (IllegalStateException ex) {
129       // Do nothing.
130     }
131   }
132 
133   /**
134    * NettyTsiHandshaker just converts {@link ByteBuffer} to {@link ByteBuf}, so check that the other
135    * methods are otherwise the same.
136    */
137   @Test
handshakerMethodsMatch()138   public void handshakerMethodsMatch() {
139     List<String> expectedMethods = new ArrayList<>();
140     for (Method m : TsiHandshaker.class.getDeclaredMethods()) {
141       expectedMethods.add(m.getName());
142     }
143 
144     List<String> actualMethods = new ArrayList<>();
145     for (Method m : NettyTsiHandshaker.class.getDeclaredMethods()) {
146       actualMethods.add(m.getName());
147     }
148 
149     assertThat(actualMethods).containsAllIn(expectedMethods);
150   }
151 
doHandshake( NettyTsiHandshaker clientHandshaker, NettyTsiHandshaker serverHandshaker, ByteBufAllocator alloc, Function<ByteBuf, ByteBuf> ref)152   static void doHandshake(
153       NettyTsiHandshaker clientHandshaker,
154       NettyTsiHandshaker serverHandshaker,
155       ByteBufAllocator alloc,
156       Function<ByteBuf, ByteBuf> ref)
157       throws GeneralSecurityException {
158     // Get the server response handshake frames.
159     for (int i = 0; i < 10; i++) {
160       if (!(clientHandshaker.isInProgress() || serverHandshaker.isInProgress())) {
161         return;
162       }
163 
164       ByteBuf clientData = ref.apply(alloc.buffer());
165       clientHandshaker.getBytesToSendToPeer(clientData);
166       if (clientData.isReadable()) {
167         serverHandshaker.processBytesFromPeer(clientData);
168       }
169 
170       ByteBuf serverData = ref.apply(alloc.buffer());
171       serverHandshaker.getBytesToSendToPeer(serverData);
172       if (serverData.isReadable()) {
173         clientHandshaker.processBytesFromPeer(serverData);
174       }
175     }
176 
177     throw new AssertionError("Failed to complete the handshake.");
178   }
179 
doHandshake()180   private void doHandshake() throws GeneralSecurityException {
181     doHandshake(
182         clientHandshaker,
183         serverHandshaker,
184         alloc,
185         new Function<ByteBuf, ByteBuf>() {
186           @Override
187           public ByteBuf apply(ByteBuf buf) {
188             return ref(buf);
189           }
190         });
191   }
192 
ref(ByteBuf buf)193   private ByteBuf ref(ByteBuf buf) {
194     if (buf != null) {
195       references.add(buf);
196     }
197     return buf;
198   }
199 
200   /** A mirror of java.util.function.Function without the Java 8 dependency. */
201   private interface Function<T, R> {
apply(T t)202     R apply(T t);
203   }
204 }
205