/*
 * Copyright 2018 The gRPC Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.grpc.alts.internal;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.protobuf.ByteString;
import io.grpc.alts.internal.Handshaker.HandshakerResult;
import io.grpc.alts.internal.Handshaker.Identity;
import io.grpc.alts.internal.TransportSecurityCommon.RpcProtocolVersions;
import java.nio.ByteBuffer;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Matchers;

/**
 * Unit tests for {@link AltsTsiHandshaker}.
 *
 * <p>Each test drives a client-side or server-side {@link AltsTsiHandshaker} backed by a mocked
 * {@link AltsHandshakerClient}. All {@code verify(..., never())} checks are placed <em>after</em>
 * the handshaker has been exercised, because Mockito verifies against the interactions recorded
 * so far; a check made before the call under test can never fail.
 */
@RunWith(JUnit4.class)
public class AltsTsiHandshakerTest {
  private static final String TEST_KEY_DATA = "super secret 123";
  private static final String TEST_APPLICATION_PROTOCOL = "grpc";
  private static final String TEST_RECORD_PROTOCOL = "ALTSRP_GCM_AES128";
  private static final String TEST_CLIENT_SERVICE_ACCOUNT = "client@developer.gserviceaccount.com";
  private static final String TEST_SERVER_SERVICE_ACCOUNT = "server@developer.gserviceaccount.com";
  private static final int OUT_FRAME_SIZE = 100;
  private static final int TRANSPORT_BUFFER_SIZE = 200;
  private static final int TEST_MAX_RPC_VERSION_MAJOR = 3;
  private static final int TEST_MAX_RPC_VERSION_MINOR = 2;
  private static final int TEST_MIN_RPC_VERSION_MAJOR = 2;
  private static final int TEST_MIN_RPC_VERSION_MINOR = 1;
  private static final RpcProtocolVersions TEST_RPC_PROTOCOL_VERSIONS =
      RpcProtocolVersions.newBuilder()
          .setMaxRpcVersion(
              RpcProtocolVersions.Version.newBuilder()
                  .setMajor(TEST_MAX_RPC_VERSION_MAJOR)
                  .setMinor(TEST_MAX_RPC_VERSION_MINOR)
                  .build())
          .setMinRpcVersion(
              RpcProtocolVersions.Version.newBuilder()
                  .setMajor(TEST_MIN_RPC_VERSION_MAJOR)
                  .setMinor(TEST_MIN_RPC_VERSION_MINOR)
                  .build())
          .build();

  private AltsHandshakerClient mockClient;
  private AltsHandshakerClient mockServer;
  private AltsTsiHandshaker handshakerClient;
  private AltsTsiHandshaker handshakerServer;

  /** Creates fresh mocks and one client-side and one server-side handshaker for every test. */
  @Before
  public void setUp() throws Exception {
    mockClient = mock(AltsHandshakerClient.class);
    mockServer = mock(AltsHandshakerClient.class);
    handshakerClient = new AltsTsiHandshaker(true, mockClient);
    handshakerServer = new AltsTsiHandshaker(false, mockServer);
  }

  /**
   * Builds the handshaker result the mocked handshaker client reports.
   *
   * @param isClient {@code true} to build the result as seen by the client (peer is the server
   *     service account, local is the client service account); {@code false} for the reverse
   */
  private HandshakerResult getHandshakerResult(boolean isClient) {
    HandshakerResult.Builder builder =
        HandshakerResult.newBuilder()
            .setApplicationProtocol(TEST_APPLICATION_PROTOCOL)
            .setRecordProtocol(TEST_RECORD_PROTOCOL)
            .setKeyData(ByteString.copyFromUtf8(TEST_KEY_DATA))
            .setPeerRpcVersions(TEST_RPC_PROTOCOL_VERSIONS);
    if (isClient) {
      builder.setPeerIdentity(
          Identity.newBuilder().setServiceAccount(TEST_SERVER_SERVICE_ACCOUNT).build());
      builder.setLocalIdentity(
          Identity.newBuilder().setServiceAccount(TEST_CLIENT_SERVICE_ACCOUNT).build());
    } else {
      builder.setPeerIdentity(
          Identity.newBuilder().setServiceAccount(TEST_CLIENT_SERVICE_ACCOUNT).build());
      builder.setLocalIdentity(
          Identity.newBuilder().setServiceAccount(TEST_SERVER_SERVICE_ACCOUNT).build());
    }
    return builder.build();
  }

  /**
   * A client that receives peer bytes before it has produced any output is expected to report
   * {@code true} without touching the handshaker service.
   */
  @Test
  public void processBytesFromPeerFalseStart() throws Exception {
    ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);

    assertTrue(handshakerClient.processBytesFromPeer(transportBuffer));

    verify(mockClient, never()).startClientHandshake();
    verify(mockClient, never()).startServerHandshake(Matchers.<ByteBuffer>any());
    verify(mockClient, never()).next(Matchers.<ByteBuffer>any());
  }

  /** Server start: input fully consumed and a non-empty output frame is available. */
  @Test
  public void processBytesFromPeerStartServer() throws Exception {
    ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
    ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
    // Mock transport buffer all consumed by processBytesFromPeer and there is an output frame.
    transportBuffer.position(transportBuffer.limit());
    when(mockServer.startServerHandshake(transportBuffer)).thenReturn(outputFrame);
    when(mockServer.isFinished()).thenReturn(false);

    assertTrue(handshakerServer.processBytesFromPeer(transportBuffer));

    verify(mockServer, never()).startClientHandshake();
    verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
  }

  /** Server start: input fully consumed but the output frame is empty. */
  @Test
  public void processBytesFromPeerStartServerEmptyOutput() throws Exception {
    ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
    ByteBuffer emptyOutputFrame = ByteBuffer.allocate(0);
    // Mock transport buffer all consumed by processBytesFromPeer and output frame is empty.
    // Expect processBytesFromPeer to return false, because more data is needed from the peer.
    transportBuffer.position(transportBuffer.limit());
    when(mockServer.startServerHandshake(transportBuffer)).thenReturn(emptyOutputFrame);
    when(mockServer.isFinished()).thenReturn(false);

    assertFalse(handshakerServer.processBytesFromPeer(transportBuffer));

    verify(mockServer, never()).startClientHandshake();
    verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
  }

  /** Server start: the mocked handshaker reports completion right after the first call. */
  @Test
  public void processBytesFromPeerStartServerFinished() throws Exception {
    ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
    ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
    // Mock handshake complete after processBytesFromPeer.
    when(mockServer.startServerHandshake(transportBuffer)).thenReturn(outputFrame);
    when(mockServer.isFinished()).thenReturn(true);

    assertTrue(handshakerServer.processBytesFromPeer(transportBuffer));

    verify(mockServer, never()).startClientHandshake();
    verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
  }

  /**
   * Server start where the mocked handshaker neither consumes input nor produces output: the
   * handshaker under test is expected to fail with an {@link IllegalStateException}.
   */
  @Test
  public void processBytesFromPeerNoBytesConsumed() throws Exception {
    ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
    ByteBuffer emptyOutputFrame = ByteBuffer.allocate(0);
    when(mockServer.startServerHandshake(transportBuffer)).thenReturn(emptyOutputFrame);
    when(mockServer.isFinished()).thenReturn(false);

    try {
      // The call is expected to throw, so its return value is deliberately not asserted on.
      handshakerServer.processBytesFromPeer(transportBuffer);
      fail("Expected IllegalStateException");
    } catch (IllegalStateException expected) {
      assertEquals("Handshaker did not consume any bytes.", expected.getMessage());
    }

    verify(mockServer, never()).startClientHandshake();
    verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
  }

  /** Client continues an unfinished handshake after having sent its first frame. */
  @Test
  public void processBytesFromPeerClientNext() throws Exception {
    ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
    ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
    when(mockClient.startClientHandshake()).thenReturn(outputFrame);
    when(mockClient.next(transportBuffer)).thenReturn(outputFrame);
    when(mockClient.isFinished()).thenReturn(false);

    handshakerClient.getBytesToSendToPeer(transportBuffer);
    // Leave no remaining bytes in the buffer handed to processBytesFromPeer.
    transportBuffer.position(transportBuffer.limit());
    assertFalse(handshakerClient.processBytesFromPeer(transportBuffer));

    verify(mockClient, never()).startServerHandshake(Matchers.<ByteBuffer>any());
  }

  /** Client continues the handshake and the mocked handshaker reports completion. */
  @Test
  public void processBytesFromPeerClientNextFinished() throws Exception {
    ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
    ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
    when(mockClient.startClientHandshake()).thenReturn(outputFrame);
    when(mockClient.next(transportBuffer)).thenReturn(outputFrame);
    when(mockClient.isFinished()).thenReturn(true);

    handshakerClient.getBytesToSendToPeer(transportBuffer);
    assertTrue(handshakerClient.processBytesFromPeer(transportBuffer));

    verify(mockClient, never()).startServerHandshake(Matchers.<ByteBuffer>any());
  }

  /** {@code extractPeer} must be rejected while the handshake is not finished. */
  @Test
  public void extractPeerFailure() throws Exception {
    when(mockClient.isFinished()).thenReturn(false);

    try {
      handshakerClient.extractPeer();
      fail("Expected IllegalStateException");
    } catch (IllegalStateException expected) {
      assertEquals("Handshake is not complete.", expected.getMessage());
    }
  }

  /** {@code extractPeerObject} must be rejected while the handshake is not finished. */
  @Test
  public void extractPeerObjectFailure() throws Exception {
    when(mockClient.isFinished()).thenReturn(false);

    try {
      handshakerClient.extractPeerObject();
      fail("Expected IllegalStateException");
    } catch (IllegalStateException expected) {
      assertEquals("Handshake is not complete.", expected.getMessage());
    }
  }

  /** After a finished client handshake, the peer and auth context describe the server. */
  @Test
  public void extractClientPeerSuccess() throws Exception {
    ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
    ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
    when(mockClient.startClientHandshake()).thenReturn(outputFrame);
    when(mockClient.isFinished()).thenReturn(true);
    when(mockClient.getResult()).thenReturn(getHandshakerResult(/* isClient = */ true));

    handshakerClient.getBytesToSendToPeer(transportBuffer);
    TsiPeer clientPeer = handshakerClient.extractPeer();

    assertEquals(1, clientPeer.getProperties().size());
    assertEquals(
        TEST_SERVER_SERVICE_ACCOUNT,
        clientPeer.getProperty(AltsTsiHandshaker.TSI_SERVICE_ACCOUNT_PEER_PROPERTY).getValue());

    AltsAuthContext clientContext = (AltsAuthContext) handshakerClient.extractPeerObject();
    assertEquals(TEST_APPLICATION_PROTOCOL, clientContext.getApplicationProtocol());
    assertEquals(TEST_RECORD_PROTOCOL, clientContext.getRecordProtocol());
    assertEquals(TEST_SERVER_SERVICE_ACCOUNT, clientContext.getPeerServiceAccount());
    assertEquals(TEST_CLIENT_SERVICE_ACCOUNT, clientContext.getLocalServiceAccount());
    assertEquals(TEST_RPC_PROTOCOL_VERSIONS, clientContext.getPeerRpcVersions());
  }

  /** After a finished server handshake, the peer and auth context describe the client. */
  @Test
  public void extractServerPeerSuccess() throws Exception {
    ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
    ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
    when(mockServer.startServerHandshake(Matchers.<ByteBuffer>any())).thenReturn(outputFrame);
    when(mockServer.isFinished()).thenReturn(true);
    when(mockServer.getResult()).thenReturn(getHandshakerResult(/* isClient = */ false));

    handshakerServer.processBytesFromPeer(transportBuffer);
    handshakerServer.getBytesToSendToPeer(transportBuffer);
    TsiPeer serverPeer = handshakerServer.extractPeer();

    assertEquals(1, serverPeer.getProperties().size());
    assertEquals(
        TEST_CLIENT_SERVICE_ACCOUNT,
        serverPeer.getProperty(AltsTsiHandshaker.TSI_SERVICE_ACCOUNT_PEER_PROPERTY).getValue());

    AltsAuthContext serverContext = (AltsAuthContext) handshakerServer.extractPeerObject();
    assertEquals(TEST_APPLICATION_PROTOCOL, serverContext.getApplicationProtocol());
    assertEquals(TEST_RECORD_PROTOCOL, serverContext.getRecordProtocol());
    assertEquals(TEST_CLIENT_SERVICE_ACCOUNT, serverContext.getPeerServiceAccount());
    assertEquals(TEST_SERVER_SERVICE_ACCOUNT, serverContext.getLocalServiceAccount());
    assertEquals(TEST_RPC_PROTOCOL_VERSIONS, serverContext.getPeerRpcVersions());
  }
}