• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2018 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.alts.internal;
18 
19 import static org.junit.Assert.assertEquals;
20 import static org.junit.Assert.assertFalse;
21 import static org.junit.Assert.assertTrue;
22 import static org.junit.Assert.fail;
23 import static org.mockito.Mockito.mock;
24 import static org.mockito.Mockito.never;
25 import static org.mockito.Mockito.verify;
26 import static org.mockito.Mockito.when;
27 
28 import com.google.protobuf.ByteString;
29 import io.grpc.alts.internal.Handshaker.HandshakerResult;
30 import io.grpc.alts.internal.Handshaker.Identity;
31 import io.grpc.alts.internal.TransportSecurityCommon.RpcProtocolVersions;
32 import java.nio.ByteBuffer;
33 import org.junit.Before;
34 import org.junit.Test;
35 import org.junit.runner.RunWith;
36 import org.junit.runners.JUnit4;
37 import org.mockito.Matchers;
38 
39 /** Unit tests for {@link AltsTsiHandshaker}. */
40 @RunWith(JUnit4.class)
41 public class AltsTsiHandshakerTest {
42   private static final String TEST_KEY_DATA = "super secret 123";
43   private static final String TEST_APPLICATION_PROTOCOL = "grpc";
44   private static final String TEST_RECORD_PROTOCOL = "ALTSRP_GCM_AES128";
45   private static final String TEST_CLIENT_SERVICE_ACCOUNT = "client@developer.gserviceaccount.com";
46   private static final String TEST_SERVER_SERVICE_ACCOUNT = "server@developer.gserviceaccount.com";
47   private static final int OUT_FRAME_SIZE = 100;
48   private static final int TRANSPORT_BUFFER_SIZE = 200;
49   private static final int TEST_MAX_RPC_VERSION_MAJOR = 3;
50   private static final int TEST_MAX_RPC_VERSION_MINOR = 2;
51   private static final int TEST_MIN_RPC_VERSION_MAJOR = 2;
52   private static final int TEST_MIN_RPC_VERSION_MINOR = 1;
53   private static final RpcProtocolVersions TEST_RPC_PROTOCOL_VERSIONS =
54       RpcProtocolVersions.newBuilder()
55           .setMaxRpcVersion(
56               RpcProtocolVersions.Version.newBuilder()
57                   .setMajor(TEST_MAX_RPC_VERSION_MAJOR)
58                   .setMinor(TEST_MAX_RPC_VERSION_MINOR)
59                   .build())
60           .setMinRpcVersion(
61               RpcProtocolVersions.Version.newBuilder()
62                   .setMajor(TEST_MIN_RPC_VERSION_MAJOR)
63                   .setMinor(TEST_MIN_RPC_VERSION_MINOR)
64                   .build())
65           .build();
66 
67   private AltsHandshakerClient mockClient;
68   private AltsHandshakerClient mockServer;
69   private AltsTsiHandshaker handshakerClient;
70   private AltsTsiHandshaker handshakerServer;
71 
72   @Before
setUp()73   public void setUp() throws Exception {
74     mockClient = mock(AltsHandshakerClient.class);
75     mockServer = mock(AltsHandshakerClient.class);
76     handshakerClient = new AltsTsiHandshaker(true, mockClient);
77     handshakerServer = new AltsTsiHandshaker(false, mockServer);
78   }
79 
getHandshakerResult(boolean isClient)80   private HandshakerResult getHandshakerResult(boolean isClient) {
81     HandshakerResult.Builder builder =
82         HandshakerResult.newBuilder()
83             .setApplicationProtocol(TEST_APPLICATION_PROTOCOL)
84             .setRecordProtocol(TEST_RECORD_PROTOCOL)
85             .setKeyData(ByteString.copyFromUtf8(TEST_KEY_DATA))
86             .setPeerRpcVersions(TEST_RPC_PROTOCOL_VERSIONS);
87     if (isClient) {
88       builder.setPeerIdentity(
89           Identity.newBuilder().setServiceAccount(TEST_SERVER_SERVICE_ACCOUNT).build());
90       builder.setLocalIdentity(
91           Identity.newBuilder().setServiceAccount(TEST_CLIENT_SERVICE_ACCOUNT).build());
92     } else {
93       builder.setPeerIdentity(
94           Identity.newBuilder().setServiceAccount(TEST_CLIENT_SERVICE_ACCOUNT).build());
95       builder.setLocalIdentity(
96           Identity.newBuilder().setServiceAccount(TEST_SERVER_SERVICE_ACCOUNT).build());
97     }
98     return builder.build();
99   }
100 
101   @Test
processBytesFromPeerFalseStart()102   public void processBytesFromPeerFalseStart() throws Exception {
103     verify(mockClient, never()).startClientHandshake();
104     verify(mockClient, never()).startServerHandshake(Matchers.<ByteBuffer>any());
105     verify(mockClient, never()).next(Matchers.<ByteBuffer>any());
106 
107     ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
108     assertTrue(handshakerClient.processBytesFromPeer(transportBuffer));
109   }
110 
111   @Test
processBytesFromPeerStartServer()112   public void processBytesFromPeerStartServer() throws Exception {
113     ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
114     ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
115     verify(mockServer, never()).startClientHandshake();
116     verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
117     // Mock transport buffer all consumed by processBytesFromPeer and there is an output frame.
118     transportBuffer.position(transportBuffer.limit());
119     when(mockServer.startServerHandshake(transportBuffer)).thenReturn(outputFrame);
120     when(mockServer.isFinished()).thenReturn(false);
121 
122     assertTrue(handshakerServer.processBytesFromPeer(transportBuffer));
123   }
124 
125   @Test
processBytesFromPeerStartServerEmptyOutput()126   public void processBytesFromPeerStartServerEmptyOutput() throws Exception {
127     ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
128     ByteBuffer emptyOutputFrame = ByteBuffer.allocate(0);
129     verify(mockServer, never()).startClientHandshake();
130     verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
131     // Mock transport buffer all consumed by processBytesFromPeer and output frame is empty.
132     // Expect processBytesFromPeer return False, because more data are needed from the peer.
133     transportBuffer.position(transportBuffer.limit());
134     when(mockServer.startServerHandshake(transportBuffer)).thenReturn(emptyOutputFrame);
135     when(mockServer.isFinished()).thenReturn(false);
136 
137     assertFalse(handshakerServer.processBytesFromPeer(transportBuffer));
138   }
139 
140   @Test
processBytesFromPeerStartServerFinished()141   public void processBytesFromPeerStartServerFinished() throws Exception {
142     ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
143     ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
144     verify(mockServer, never()).startClientHandshake();
145     verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
146     // Mock handshake complete after processBytesFromPeer.
147     when(mockServer.startServerHandshake(transportBuffer)).thenReturn(outputFrame);
148     when(mockServer.isFinished()).thenReturn(true);
149 
150     assertTrue(handshakerServer.processBytesFromPeer(transportBuffer));
151   }
152 
153   @Test
processBytesFromPeerNoBytesConsumed()154   public void processBytesFromPeerNoBytesConsumed() throws Exception {
155     ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
156     ByteBuffer emptyOutputFrame = ByteBuffer.allocate(0);
157     verify(mockServer, never()).startClientHandshake();
158     verify(mockServer, never()).next(Matchers.<ByteBuffer>any());
159     when(mockServer.startServerHandshake(transportBuffer)).thenReturn(emptyOutputFrame);
160     when(mockServer.isFinished()).thenReturn(false);
161 
162     try {
163       assertTrue(handshakerServer.processBytesFromPeer(transportBuffer));
164       fail("Expected IllegalStateException");
165     } catch (IllegalStateException expected) {
166       assertEquals("Handshaker did not consume any bytes.", expected.getMessage());
167     }
168   }
169 
170   @Test
processBytesFromPeerClientNext()171   public void processBytesFromPeerClientNext() throws Exception {
172     ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
173     ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
174     verify(mockClient, never()).startServerHandshake(Matchers.<ByteBuffer>any());
175     when(mockClient.startClientHandshake()).thenReturn(outputFrame);
176     when(mockClient.next(transportBuffer)).thenReturn(outputFrame);
177     when(mockClient.isFinished()).thenReturn(false);
178 
179     handshakerClient.getBytesToSendToPeer(transportBuffer);
180     transportBuffer.position(transportBuffer.limit());
181     assertFalse(handshakerClient.processBytesFromPeer(transportBuffer));
182   }
183 
184   @Test
processBytesFromPeerClientNextFinished()185   public void processBytesFromPeerClientNextFinished() throws Exception {
186     ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
187     ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
188     verify(mockClient, never()).startServerHandshake(Matchers.<ByteBuffer>any());
189     when(mockClient.startClientHandshake()).thenReturn(outputFrame);
190     when(mockClient.next(transportBuffer)).thenReturn(outputFrame);
191     when(mockClient.isFinished()).thenReturn(true);
192 
193     handshakerClient.getBytesToSendToPeer(transportBuffer);
194     assertTrue(handshakerClient.processBytesFromPeer(transportBuffer));
195   }
196 
197   @Test
extractPeerFailure()198   public void extractPeerFailure() throws Exception {
199     when(mockClient.isFinished()).thenReturn(false);
200 
201     try {
202       handshakerClient.extractPeer();
203       fail("Expected IllegalStateException");
204     } catch (IllegalStateException expected) {
205       assertEquals("Handshake is not complete.", expected.getMessage());
206     }
207   }
208 
209   @Test
extractPeerObjectFailure()210   public void extractPeerObjectFailure() throws Exception {
211     when(mockClient.isFinished()).thenReturn(false);
212 
213     try {
214       handshakerClient.extractPeerObject();
215       fail("Expected IllegalStateException");
216     } catch (IllegalStateException expected) {
217       assertEquals("Handshake is not complete.", expected.getMessage());
218     }
219   }
220 
221   @Test
extractClientPeerSuccess()222   public void extractClientPeerSuccess() throws Exception {
223     ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
224     ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
225     when(mockClient.startClientHandshake()).thenReturn(outputFrame);
226     when(mockClient.isFinished()).thenReturn(true);
227     when(mockClient.getResult()).thenReturn(getHandshakerResult(/* isClient = */ true));
228 
229     handshakerClient.getBytesToSendToPeer(transportBuffer);
230     TsiPeer clientPeer = handshakerClient.extractPeer();
231 
232     assertEquals(1, clientPeer.getProperties().size());
233     assertEquals(
234         TEST_SERVER_SERVICE_ACCOUNT,
235         clientPeer.getProperty(AltsTsiHandshaker.TSI_SERVICE_ACCOUNT_PEER_PROPERTY).getValue());
236 
237     AltsAuthContext clientContext = (AltsAuthContext) handshakerClient.extractPeerObject();
238     assertEquals(TEST_APPLICATION_PROTOCOL, clientContext.getApplicationProtocol());
239     assertEquals(TEST_RECORD_PROTOCOL, clientContext.getRecordProtocol());
240     assertEquals(TEST_SERVER_SERVICE_ACCOUNT, clientContext.getPeerServiceAccount());
241     assertEquals(TEST_CLIENT_SERVICE_ACCOUNT, clientContext.getLocalServiceAccount());
242     assertEquals(TEST_RPC_PROTOCOL_VERSIONS, clientContext.getPeerRpcVersions());
243   }
244 
245   @Test
extractServerPeerSuccess()246   public void extractServerPeerSuccess() throws Exception {
247     ByteBuffer outputFrame = ByteBuffer.allocate(OUT_FRAME_SIZE);
248     ByteBuffer transportBuffer = ByteBuffer.allocate(TRANSPORT_BUFFER_SIZE);
249     when(mockServer.startServerHandshake(Matchers.<ByteBuffer>any())).thenReturn(outputFrame);
250     when(mockServer.isFinished()).thenReturn(true);
251     when(mockServer.getResult()).thenReturn(getHandshakerResult(/* isClient = */ false));
252 
253     handshakerServer.processBytesFromPeer(transportBuffer);
254     handshakerServer.getBytesToSendToPeer(transportBuffer);
255     TsiPeer serverPeer = handshakerServer.extractPeer();
256 
257     assertEquals(1, serverPeer.getProperties().size());
258     assertEquals(
259         TEST_CLIENT_SERVICE_ACCOUNT,
260         serverPeer.getProperty(AltsTsiHandshaker.TSI_SERVICE_ACCOUNT_PEER_PROPERTY).getValue());
261 
262     AltsAuthContext serverContext = (AltsAuthContext) handshakerServer.extractPeerObject();
263     assertEquals(TEST_APPLICATION_PROTOCOL, serverContext.getApplicationProtocol());
264     assertEquals(TEST_RECORD_PROTOCOL, serverContext.getRecordProtocol());
265     assertEquals(TEST_CLIENT_SERVICE_ACCOUNT, serverContext.getPeerServiceAccount());
266     assertEquals(TEST_SERVER_SERVICE_ACCOUNT, serverContext.getLocalServiceAccount());
267     assertEquals(TEST_RPC_PROTOCOL_VERSIONS, serverContext.getPeerRpcVersions());
268   }
269 }
270