1 /* 2 * Copyright 2018 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.alts.internal; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static org.junit.Assert.assertFalse; 21 import static org.junit.Assert.assertNotNull; 22 import static org.junit.Assert.assertTrue; 23 import static org.junit.Assert.fail; 24 25 import io.netty.buffer.ByteBuf; 26 import io.netty.buffer.ByteBufAllocator; 27 import io.netty.buffer.UnpooledByteBufAllocator; 28 import io.netty.util.ReferenceCounted; 29 import java.lang.reflect.Method; 30 import java.nio.ByteBuffer; 31 import java.security.GeneralSecurityException; 32 import java.util.ArrayList; 33 import java.util.List; 34 import org.junit.After; 35 import org.junit.Test; 36 import org.junit.runner.RunWith; 37 import org.junit.runners.JUnit4; 38 39 @RunWith(JUnit4.class) 40 public class NettyTsiHandshakerTest { 41 private final UnpooledByteBufAllocator alloc = UnpooledByteBufAllocator.DEFAULT; 42 private final List<ReferenceCounted> references = new ArrayList<>(); 43 44 private final NettyTsiHandshaker clientHandshaker = 45 new NettyTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerClient()); 46 private final NettyTsiHandshaker serverHandshaker = 47 new NettyTsiHandshaker(FakeTsiHandshaker.newFakeHandshakerServer()); 48 49 @After teardown()50 public void teardown() { 51 for (ReferenceCounted reference : references) { 52 reference.release(reference.refCnt()); 53 } 54 } 55 56 @Test failsOnNullHandshaker()57 public void failsOnNullHandshaker() { 58 try { 59 new NettyTsiHandshaker(null); 60 fail("Exception expected"); 61 } catch (NullPointerException ex) { 62 // Do nothing. 63 } 64 } 65 66 @Test processPeerHandshakeShouldAcceptPartialFrames()67 public void processPeerHandshakeShouldAcceptPartialFrames() throws GeneralSecurityException { 68 for (int i = 0; i < 1024; i++) { 69 ByteBuf clientData = ref(alloc.buffer(1)); 70 clientHandshaker.getBytesToSendToPeer(clientData); 71 if (clientData.isReadable()) { 72 if (serverHandshaker.processBytesFromPeer(clientData)) { 73 // Done. 74 return; 75 } 76 } 77 } 78 fail("Failed to process the handshake frame."); 79 } 80 81 @Test handshakeShouldSucceed()82 public void handshakeShouldSucceed() throws GeneralSecurityException { 83 doHandshake(); 84 } 85 86 @Test isInProgress()87 public void isInProgress() throws GeneralSecurityException { 88 assertTrue(clientHandshaker.isInProgress()); 89 assertTrue(serverHandshaker.isInProgress()); 90 91 doHandshake(); 92 93 assertFalse(clientHandshaker.isInProgress()); 94 assertFalse(serverHandshaker.isInProgress()); 95 } 96 97 @Test extractPeer_notNull()98 public void extractPeer_notNull() throws GeneralSecurityException { 99 doHandshake(); 100 101 assertNotNull(serverHandshaker.extractPeer()); 102 assertNotNull(clientHandshaker.extractPeer()); 103 } 104 105 @Test extractPeer_failsBeforeHandshake()106 public void extractPeer_failsBeforeHandshake() throws GeneralSecurityException { 107 try { 108 clientHandshaker.extractPeer(); 109 fail("Exception expected"); 110 } catch (IllegalStateException ex) { 111 // Do nothing. 112 } 113 } 114 115 @Test extractPeerObject_notNull()116 public void extractPeerObject_notNull() throws GeneralSecurityException { 117 doHandshake(); 118 119 assertNotNull(serverHandshaker.extractPeerObject()); 120 assertNotNull(clientHandshaker.extractPeerObject()); 121 } 122 123 @Test extractPeerObject_failsBeforeHandshake()124 public void extractPeerObject_failsBeforeHandshake() throws GeneralSecurityException { 125 try { 126 clientHandshaker.extractPeerObject(); 127 fail("Exception expected"); 128 } catch (IllegalStateException ex) { 129 // Do nothing. 130 } 131 } 132 133 /** 134 * NettyTsiHandshaker just converts {@link ByteBuffer} to {@link ByteBuf}, so check that the other 135 * methods are otherwise the same. 136 */ 137 @Test handshakerMethodsMatch()138 public void handshakerMethodsMatch() { 139 List<String> expectedMethods = new ArrayList<>(); 140 for (Method m : TsiHandshaker.class.getDeclaredMethods()) { 141 expectedMethods.add(m.getName()); 142 } 143 144 List<String> actualMethods = new ArrayList<>(); 145 for (Method m : NettyTsiHandshaker.class.getDeclaredMethods()) { 146 actualMethods.add(m.getName()); 147 } 148 149 assertThat(actualMethods).containsAllIn(expectedMethods); 150 } 151 doHandshake( NettyTsiHandshaker clientHandshaker, NettyTsiHandshaker serverHandshaker, ByteBufAllocator alloc, Function<ByteBuf, ByteBuf> ref)152 static void doHandshake( 153 NettyTsiHandshaker clientHandshaker, 154 NettyTsiHandshaker serverHandshaker, 155 ByteBufAllocator alloc, 156 Function<ByteBuf, ByteBuf> ref) 157 throws GeneralSecurityException { 158 // Get the server response handshake frames. 159 for (int i = 0; i < 10; i++) { 160 if (!(clientHandshaker.isInProgress() || serverHandshaker.isInProgress())) { 161 return; 162 } 163 164 ByteBuf clientData = ref.apply(alloc.buffer()); 165 clientHandshaker.getBytesToSendToPeer(clientData); 166 if (clientData.isReadable()) { 167 serverHandshaker.processBytesFromPeer(clientData); 168 } 169 170 ByteBuf serverData = ref.apply(alloc.buffer()); 171 serverHandshaker.getBytesToSendToPeer(serverData); 172 if (serverData.isReadable()) { 173 clientHandshaker.processBytesFromPeer(serverData); 174 } 175 } 176 177 throw new AssertionError("Failed to complete the handshake."); 178 } 179 doHandshake()180 private void doHandshake() throws GeneralSecurityException { 181 doHandshake( 182 clientHandshaker, 183 serverHandshaker, 184 alloc, 185 new Function<ByteBuf, ByteBuf>() { 186 @Override 187 public ByteBuf apply(ByteBuf buf) { 188 return ref(buf); 189 } 190 }); 191 } 192 ref(ByteBuf buf)193 private ByteBuf ref(ByteBuf buf) { 194 if (buf != null) { 195 references.add(buf); 196 } 197 return buf; 198 } 199 200 /** A mirror of java.util.function.Function without the Java 8 dependency. */ 201 private interface Function<T, R> { apply(T t)202 R apply(T t); 203 } 204 } 205