1 /* 2 * Copyright 2018 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.alts.internal; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static org.junit.Assert.assertArrayEquals; 21 import static org.junit.Assert.assertEquals; 22 import static org.junit.Assert.assertFalse; 23 import static org.junit.Assert.assertNull; 24 import static org.junit.Assert.assertTrue; 25 import static org.junit.Assert.fail; 26 import static org.mockito.Mockito.mock; 27 import static org.mockito.Mockito.verify; 28 import static org.mockito.Mockito.when; 29 30 import com.google.common.collect.ImmutableList; 31 import com.google.protobuf.ByteString; 32 import io.grpc.alts.internal.Handshaker.HandshakeProtocol; 33 import io.grpc.alts.internal.Handshaker.HandshakerReq; 34 import io.grpc.alts.internal.Handshaker.Identity; 35 import io.grpc.alts.internal.Handshaker.StartClientHandshakeReq; 36 import io.grpc.alts.internal.TransportSecurityCommon.RpcProtocolVersions; 37 import java.nio.ByteBuffer; 38 import java.security.GeneralSecurityException; 39 import org.junit.Before; 40 import org.junit.Test; 41 import org.junit.runner.RunWith; 42 import org.junit.runners.JUnit4; 43 import org.mockito.ArgumentCaptor; 44 import org.mockito.Matchers; 45 46 /** Unit tests for {@link AltsHandshakerClient}. */ 47 @RunWith(JUnit4.class) 48 public class AltsHandshakerClientTest { 49 private static final int IN_BYTES_SIZE = 100; 50 private static final int BYTES_CONSUMED = 30; 51 private static final int PREFIX_POSITION = 20; 52 private static final String TEST_TARGET_NAME = "target name"; 53 private static final String TEST_TARGET_SERVICE_ACCOUNT = "peer service account"; 54 55 private AltsHandshakerStub mockStub; 56 private AltsHandshakerClient handshaker; 57 private AltsClientOptions clientOptions; 58 59 @Before setUp()60 public void setUp() { 61 mockStub = mock(AltsHandshakerStub.class); 62 clientOptions = 63 new AltsClientOptions.Builder() 64 .setTargetName(TEST_TARGET_NAME) 65 .setTargetServiceAccounts(ImmutableList.of(TEST_TARGET_SERVICE_ACCOUNT)) 66 .build(); 67 handshaker = new AltsHandshakerClient(mockStub, clientOptions); 68 } 69 70 @Test startClientHandshakeFailure()71 public void startClientHandshakeFailure() throws Exception { 72 when(mockStub.send(Matchers.<HandshakerReq>any())) 73 .thenReturn(MockAltsHandshakerResp.getErrorResponse()); 74 75 try { 76 handshaker.startClientHandshake(); 77 fail("Exception expected"); 78 } catch (GeneralSecurityException ex) { 79 assertThat(ex).hasMessageThat().contains(MockAltsHandshakerResp.getTestErrorDetails()); 80 } 81 } 82 83 @Test startClientHandshakeSuccess()84 public void startClientHandshakeSuccess() throws Exception { 85 when(mockStub.send(Matchers.<HandshakerReq>any())) 86 .thenReturn(MockAltsHandshakerResp.getOkResponse(0)); 87 88 ByteBuffer outFrame = handshaker.startClientHandshake(); 89 90 assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame()); 91 assertFalse(handshaker.isFinished()); 92 assertNull(handshaker.getResult()); 93 assertNull(handshaker.getKey()); 94 } 95 96 @Test startClientHandshakeWithOptions()97 public void startClientHandshakeWithOptions() throws Exception { 98 when(mockStub.send(Matchers.<HandshakerReq>any())) 99 .thenReturn(MockAltsHandshakerResp.getOkResponse(0)); 100 101 ByteBuffer outFrame = handshaker.startClientHandshake(); 102 assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame()); 103 104 HandshakerReq req = 105 HandshakerReq.newBuilder() 106 .setClientStart( 107 StartClientHandshakeReq.newBuilder() 108 .setHandshakeSecurityProtocol(HandshakeProtocol.ALTS) 109 .addApplicationProtocols(AltsHandshakerClient.getApplicationProtocol()) 110 .addRecordProtocols(AltsHandshakerClient.getRecordProtocol()) 111 .setTargetName(TEST_TARGET_NAME) 112 .addTargetIdentities( 113 Identity.newBuilder().setServiceAccount(TEST_TARGET_SERVICE_ACCOUNT)) 114 .build()) 115 .build(); 116 verify(mockStub).send(req); 117 } 118 119 @Test startServerHandshakeFailure()120 public void startServerHandshakeFailure() throws Exception { 121 when(mockStub.send(Matchers.<HandshakerReq>any())) 122 .thenReturn(MockAltsHandshakerResp.getErrorResponse()); 123 124 try { 125 ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE); 126 handshaker.startServerHandshake(inBytes); 127 fail("Exception expected"); 128 } catch (GeneralSecurityException ex) { 129 assertThat(ex).hasMessageThat().contains(MockAltsHandshakerResp.getTestErrorDetails()); 130 } 131 } 132 133 @Test startServerHandshakeSuccess()134 public void startServerHandshakeSuccess() throws Exception { 135 when(mockStub.send(Matchers.<HandshakerReq>any())) 136 .thenReturn(MockAltsHandshakerResp.getOkResponse(BYTES_CONSUMED)); 137 138 ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE); 139 ByteBuffer outFrame = handshaker.startServerHandshake(inBytes); 140 141 assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame()); 142 assertFalse(handshaker.isFinished()); 143 assertNull(handshaker.getResult()); 144 assertNull(handshaker.getKey()); 145 assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining()); 146 } 147 148 @Test startServerHandshakeEmptyOutFrame()149 public void startServerHandshakeEmptyOutFrame() throws Exception { 150 when(mockStub.send(Matchers.<HandshakerReq>any())) 151 .thenReturn(MockAltsHandshakerResp.getEmptyOutFrameResponse(BYTES_CONSUMED)); 152 153 ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE); 154 ByteBuffer outFrame = handshaker.startServerHandshake(inBytes); 155 156 assertEquals(0, outFrame.remaining()); 157 assertFalse(handshaker.isFinished()); 158 assertNull(handshaker.getResult()); 159 assertNull(handshaker.getKey()); 160 assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining()); 161 } 162 163 @Test startServerHandshakeWithPrefixBuffer()164 public void startServerHandshakeWithPrefixBuffer() throws Exception { 165 when(mockStub.send(Matchers.<HandshakerReq>any())) 166 .thenReturn(MockAltsHandshakerResp.getOkResponse(BYTES_CONSUMED)); 167 168 ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE); 169 inBytes.position(PREFIX_POSITION); 170 ByteBuffer outFrame = handshaker.startServerHandshake(inBytes); 171 172 assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame()); 173 assertFalse(handshaker.isFinished()); 174 assertNull(handshaker.getResult()); 175 assertNull(handshaker.getKey()); 176 assertEquals(PREFIX_POSITION + BYTES_CONSUMED, inBytes.position()); 177 assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED - PREFIX_POSITION, inBytes.remaining()); 178 } 179 180 @Test nextFailure()181 public void nextFailure() throws Exception { 182 when(mockStub.send(Matchers.<HandshakerReq>any())) 183 .thenReturn(MockAltsHandshakerResp.getErrorResponse()); 184 185 try { 186 ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE); 187 handshaker.next(inBytes); 188 fail("Exception expected"); 189 } catch (GeneralSecurityException ex) { 190 assertThat(ex).hasMessageThat().contains(MockAltsHandshakerResp.getTestErrorDetails()); 191 } 192 } 193 194 @Test nextSuccess()195 public void nextSuccess() throws Exception { 196 when(mockStub.send(Matchers.<HandshakerReq>any())) 197 .thenReturn(MockAltsHandshakerResp.getOkResponse(BYTES_CONSUMED)); 198 199 ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE); 200 ByteBuffer outFrame = handshaker.next(inBytes); 201 202 assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame()); 203 assertFalse(handshaker.isFinished()); 204 assertNull(handshaker.getResult()); 205 assertNull(handshaker.getKey()); 206 assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining()); 207 } 208 209 @Test nextEmptyOutFrame()210 public void nextEmptyOutFrame() throws Exception { 211 when(mockStub.send(Matchers.<HandshakerReq>any())) 212 .thenReturn(MockAltsHandshakerResp.getEmptyOutFrameResponse(BYTES_CONSUMED)); 213 214 ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE); 215 ByteBuffer outFrame = handshaker.next(inBytes); 216 217 assertEquals(0, outFrame.remaining()); 218 assertFalse(handshaker.isFinished()); 219 assertNull(handshaker.getResult()); 220 assertNull(handshaker.getKey()); 221 assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining()); 222 } 223 224 @Test nextFinished()225 public void nextFinished() throws Exception { 226 when(mockStub.send(Matchers.<HandshakerReq>any())) 227 .thenReturn(MockAltsHandshakerResp.getFinishedResponse(BYTES_CONSUMED)); 228 229 ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE); 230 ByteBuffer outFrame = handshaker.next(inBytes); 231 232 assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame()); 233 assertTrue(handshaker.isFinished()); 234 assertArrayEquals(handshaker.getKey(), MockAltsHandshakerResp.getTestKeyData()); 235 assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining()); 236 } 237 238 @Test setRpcVersions()239 public void setRpcVersions() throws Exception { 240 when(mockStub.send(Matchers.<HandshakerReq>any())) 241 .thenReturn(MockAltsHandshakerResp.getOkResponse(0)); 242 243 RpcProtocolVersions rpcVersions = 244 RpcProtocolVersions.newBuilder() 245 .setMinRpcVersion( 246 RpcProtocolVersions.Version.newBuilder().setMajor(3).setMinor(4).build()) 247 .setMaxRpcVersion( 248 RpcProtocolVersions.Version.newBuilder().setMajor(5).setMinor(6).build()) 249 .build(); 250 clientOptions = 251 new AltsClientOptions.Builder() 252 .setTargetName(TEST_TARGET_NAME) 253 .setTargetServiceAccounts(ImmutableList.of(TEST_TARGET_SERVICE_ACCOUNT)) 254 .setRpcProtocolVersions(rpcVersions) 255 .build(); 256 handshaker = new AltsHandshakerClient(mockStub, clientOptions); 257 258 handshaker.startClientHandshake(); 259 260 ArgumentCaptor<HandshakerReq> reqCaptor = ArgumentCaptor.forClass(HandshakerReq.class); 261 verify(mockStub).send(reqCaptor.capture()); 262 assertEquals(rpcVersions, reqCaptor.getValue().getClientStart().getRpcVersions()); 263 } 264 } 265