• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2018 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.alts.internal;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static org.junit.Assert.assertArrayEquals;
21 import static org.junit.Assert.assertEquals;
22 import static org.junit.Assert.assertFalse;
23 import static org.junit.Assert.assertNull;
24 import static org.junit.Assert.assertTrue;
25 import static org.junit.Assert.fail;
26 import static org.mockito.Mockito.mock;
27 import static org.mockito.Mockito.verify;
28 import static org.mockito.Mockito.when;
29 
30 import com.google.common.collect.ImmutableList;
31 import com.google.protobuf.ByteString;
32 import io.grpc.alts.internal.Handshaker.HandshakeProtocol;
33 import io.grpc.alts.internal.Handshaker.HandshakerReq;
34 import io.grpc.alts.internal.Handshaker.Identity;
35 import io.grpc.alts.internal.Handshaker.StartClientHandshakeReq;
36 import io.grpc.alts.internal.TransportSecurityCommon.RpcProtocolVersions;
37 import java.nio.ByteBuffer;
38 import java.security.GeneralSecurityException;
39 import org.junit.Before;
40 import org.junit.Test;
41 import org.junit.runner.RunWith;
42 import org.junit.runners.JUnit4;
43 import org.mockito.ArgumentCaptor;
44 import org.mockito.Matchers;
45 
46 /** Unit tests for {@link AltsHandshakerClient}. */
47 @RunWith(JUnit4.class)
48 public class AltsHandshakerClientTest {
49   private static final int IN_BYTES_SIZE = 100;
50   private static final int BYTES_CONSUMED = 30;
51   private static final int PREFIX_POSITION = 20;
52   private static final String TEST_TARGET_NAME = "target name";
53   private static final String TEST_TARGET_SERVICE_ACCOUNT = "peer service account";
54 
55   private AltsHandshakerStub mockStub;
56   private AltsHandshakerClient handshaker;
57   private AltsClientOptions clientOptions;
58 
59   @Before
setUp()60   public void setUp() {
61     mockStub = mock(AltsHandshakerStub.class);
62     clientOptions =
63         new AltsClientOptions.Builder()
64             .setTargetName(TEST_TARGET_NAME)
65             .setTargetServiceAccounts(ImmutableList.of(TEST_TARGET_SERVICE_ACCOUNT))
66             .build();
67     handshaker = new AltsHandshakerClient(mockStub, clientOptions);
68   }
69 
70   @Test
startClientHandshakeFailure()71   public void startClientHandshakeFailure() throws Exception {
72     when(mockStub.send(Matchers.<HandshakerReq>any()))
73         .thenReturn(MockAltsHandshakerResp.getErrorResponse());
74 
75     try {
76       handshaker.startClientHandshake();
77       fail("Exception expected");
78     } catch (GeneralSecurityException ex) {
79       assertThat(ex).hasMessageThat().contains(MockAltsHandshakerResp.getTestErrorDetails());
80     }
81   }
82 
83   @Test
startClientHandshakeSuccess()84   public void startClientHandshakeSuccess() throws Exception {
85     when(mockStub.send(Matchers.<HandshakerReq>any()))
86         .thenReturn(MockAltsHandshakerResp.getOkResponse(0));
87 
88     ByteBuffer outFrame = handshaker.startClientHandshake();
89 
90     assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
91     assertFalse(handshaker.isFinished());
92     assertNull(handshaker.getResult());
93     assertNull(handshaker.getKey());
94   }
95 
96   @Test
startClientHandshakeWithOptions()97   public void startClientHandshakeWithOptions() throws Exception {
98     when(mockStub.send(Matchers.<HandshakerReq>any()))
99         .thenReturn(MockAltsHandshakerResp.getOkResponse(0));
100 
101     ByteBuffer outFrame = handshaker.startClientHandshake();
102     assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
103 
104     HandshakerReq req =
105         HandshakerReq.newBuilder()
106             .setClientStart(
107                 StartClientHandshakeReq.newBuilder()
108                     .setHandshakeSecurityProtocol(HandshakeProtocol.ALTS)
109                     .addApplicationProtocols(AltsHandshakerClient.getApplicationProtocol())
110                     .addRecordProtocols(AltsHandshakerClient.getRecordProtocol())
111                     .setTargetName(TEST_TARGET_NAME)
112                     .addTargetIdentities(
113                         Identity.newBuilder().setServiceAccount(TEST_TARGET_SERVICE_ACCOUNT))
114                     .build())
115             .build();
116     verify(mockStub).send(req);
117   }
118 
119   @Test
startServerHandshakeFailure()120   public void startServerHandshakeFailure() throws Exception {
121     when(mockStub.send(Matchers.<HandshakerReq>any()))
122         .thenReturn(MockAltsHandshakerResp.getErrorResponse());
123 
124     try {
125       ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
126       handshaker.startServerHandshake(inBytes);
127       fail("Exception expected");
128     } catch (GeneralSecurityException ex) {
129       assertThat(ex).hasMessageThat().contains(MockAltsHandshakerResp.getTestErrorDetails());
130     }
131   }
132 
133   @Test
startServerHandshakeSuccess()134   public void startServerHandshakeSuccess() throws Exception {
135     when(mockStub.send(Matchers.<HandshakerReq>any()))
136         .thenReturn(MockAltsHandshakerResp.getOkResponse(BYTES_CONSUMED));
137 
138     ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
139     ByteBuffer outFrame = handshaker.startServerHandshake(inBytes);
140 
141     assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
142     assertFalse(handshaker.isFinished());
143     assertNull(handshaker.getResult());
144     assertNull(handshaker.getKey());
145     assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining());
146   }
147 
148   @Test
startServerHandshakeEmptyOutFrame()149   public void startServerHandshakeEmptyOutFrame() throws Exception {
150     when(mockStub.send(Matchers.<HandshakerReq>any()))
151         .thenReturn(MockAltsHandshakerResp.getEmptyOutFrameResponse(BYTES_CONSUMED));
152 
153     ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
154     ByteBuffer outFrame = handshaker.startServerHandshake(inBytes);
155 
156     assertEquals(0, outFrame.remaining());
157     assertFalse(handshaker.isFinished());
158     assertNull(handshaker.getResult());
159     assertNull(handshaker.getKey());
160     assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining());
161   }
162 
163   @Test
startServerHandshakeWithPrefixBuffer()164   public void startServerHandshakeWithPrefixBuffer() throws Exception {
165     when(mockStub.send(Matchers.<HandshakerReq>any()))
166         .thenReturn(MockAltsHandshakerResp.getOkResponse(BYTES_CONSUMED));
167 
168     ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
169     inBytes.position(PREFIX_POSITION);
170     ByteBuffer outFrame = handshaker.startServerHandshake(inBytes);
171 
172     assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
173     assertFalse(handshaker.isFinished());
174     assertNull(handshaker.getResult());
175     assertNull(handshaker.getKey());
176     assertEquals(PREFIX_POSITION + BYTES_CONSUMED, inBytes.position());
177     assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED - PREFIX_POSITION, inBytes.remaining());
178   }
179 
180   @Test
nextFailure()181   public void nextFailure() throws Exception {
182     when(mockStub.send(Matchers.<HandshakerReq>any()))
183         .thenReturn(MockAltsHandshakerResp.getErrorResponse());
184 
185     try {
186       ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
187       handshaker.next(inBytes);
188       fail("Exception expected");
189     } catch (GeneralSecurityException ex) {
190       assertThat(ex).hasMessageThat().contains(MockAltsHandshakerResp.getTestErrorDetails());
191     }
192   }
193 
194   @Test
nextSuccess()195   public void nextSuccess() throws Exception {
196     when(mockStub.send(Matchers.<HandshakerReq>any()))
197         .thenReturn(MockAltsHandshakerResp.getOkResponse(BYTES_CONSUMED));
198 
199     ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
200     ByteBuffer outFrame = handshaker.next(inBytes);
201 
202     assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
203     assertFalse(handshaker.isFinished());
204     assertNull(handshaker.getResult());
205     assertNull(handshaker.getKey());
206     assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining());
207   }
208 
209   @Test
nextEmptyOutFrame()210   public void nextEmptyOutFrame() throws Exception {
211     when(mockStub.send(Matchers.<HandshakerReq>any()))
212         .thenReturn(MockAltsHandshakerResp.getEmptyOutFrameResponse(BYTES_CONSUMED));
213 
214     ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
215     ByteBuffer outFrame = handshaker.next(inBytes);
216 
217     assertEquals(0, outFrame.remaining());
218     assertFalse(handshaker.isFinished());
219     assertNull(handshaker.getResult());
220     assertNull(handshaker.getKey());
221     assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining());
222   }
223 
224   @Test
nextFinished()225   public void nextFinished() throws Exception {
226     when(mockStub.send(Matchers.<HandshakerReq>any()))
227         .thenReturn(MockAltsHandshakerResp.getFinishedResponse(BYTES_CONSUMED));
228 
229     ByteBuffer inBytes = ByteBuffer.allocate(IN_BYTES_SIZE);
230     ByteBuffer outFrame = handshaker.next(inBytes);
231 
232     assertEquals(ByteString.copyFrom(outFrame), MockAltsHandshakerResp.getOutFrame());
233     assertTrue(handshaker.isFinished());
234     assertArrayEquals(handshaker.getKey(), MockAltsHandshakerResp.getTestKeyData());
235     assertEquals(IN_BYTES_SIZE - BYTES_CONSUMED, inBytes.remaining());
236   }
237 
238   @Test
setRpcVersions()239   public void setRpcVersions() throws Exception {
240     when(mockStub.send(Matchers.<HandshakerReq>any()))
241         .thenReturn(MockAltsHandshakerResp.getOkResponse(0));
242 
243     RpcProtocolVersions rpcVersions =
244         RpcProtocolVersions.newBuilder()
245             .setMinRpcVersion(
246                 RpcProtocolVersions.Version.newBuilder().setMajor(3).setMinor(4).build())
247             .setMaxRpcVersion(
248                 RpcProtocolVersions.Version.newBuilder().setMajor(5).setMinor(6).build())
249             .build();
250     clientOptions =
251         new AltsClientOptions.Builder()
252             .setTargetName(TEST_TARGET_NAME)
253             .setTargetServiceAccounts(ImmutableList.of(TEST_TARGET_SERVICE_ACCOUNT))
254             .setRpcProtocolVersions(rpcVersions)
255             .build();
256     handshaker = new AltsHandshakerClient(mockStub, clientOptions);
257 
258     handshaker.startClientHandshake();
259 
260     ArgumentCaptor<HandshakerReq> reqCaptor = ArgumentCaptor.forClass(HandshakerReq.class);
261     verify(mockStub).send(reqCaptor.capture());
262     assertEquals(rpcVersions, reqCaptor.getValue().getClientStart().getRpcVersions());
263   }
264 }
265