• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2018 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.alts.internal;
18 
19 import com.google.common.annotations.VisibleForTesting;
20 import com.google.common.base.Preconditions;
21 import com.google.common.base.Strings;
22 import com.google.protobuf.ByteString;
23 import io.grpc.ChannelLogger;
24 import io.grpc.ChannelLogger.ChannelLogLevel;
25 import io.grpc.Status;
26 import io.grpc.alts.internal.HandshakerServiceGrpc.HandshakerServiceStub;
27 import java.io.IOException;
28 import java.nio.Buffer;
29 import java.nio.ByteBuffer;
30 import java.security.GeneralSecurityException;
31 
32 /** An API for conducting handshakes via ALTS handshaker service. */
33 class AltsHandshakerClient {
34   private static final String APPLICATION_PROTOCOL = "grpc";
35   private static final String RECORD_PROTOCOL = "ALTSRP_GCM_AES128_REKEY";
36   private static final int KEY_LENGTH = AltsChannelCrypter.getKeyLength();
37 
38   private final AltsHandshakerStub handshakerStub;
39   private final AltsHandshakerOptions handshakerOptions;
40   private HandshakerResult result;
41   private HandshakerStatus status;
42   private final ChannelLogger logger;
43 
44   /** Starts a new handshake interacting with the handshaker service. */
AltsHandshakerClient( HandshakerServiceStub stub, AltsHandshakerOptions options, ChannelLogger logger)45   AltsHandshakerClient(
46       HandshakerServiceStub stub, AltsHandshakerOptions options, ChannelLogger logger) {
47     handshakerStub = new AltsHandshakerStub(stub);
48     handshakerOptions = options;
49     this.logger = logger;
50   }
51 
52   @VisibleForTesting
AltsHandshakerClient( AltsHandshakerStub handshakerStub, AltsHandshakerOptions options, ChannelLogger logger)53   AltsHandshakerClient(
54       AltsHandshakerStub handshakerStub, AltsHandshakerOptions options, ChannelLogger logger) {
55     this.handshakerStub = handshakerStub;
56     handshakerOptions = options;
57     this.logger = logger;
58   }
59 
getApplicationProtocol()60   static String getApplicationProtocol() {
61     return APPLICATION_PROTOCOL;
62   }
63 
getRecordProtocol()64   static String getRecordProtocol() {
65     return RECORD_PROTOCOL;
66   }
67 
68   /** Sets the start client fields for the passed handshake request. */
setStartClientFields(HandshakerReq.Builder req)69   private void setStartClientFields(HandshakerReq.Builder req) {
70     // Sets the default values.
71     StartClientHandshakeReq.Builder startClientReq =
72         StartClientHandshakeReq.newBuilder()
73             .setHandshakeSecurityProtocol(HandshakeProtocol.ALTS)
74             .addApplicationProtocols(APPLICATION_PROTOCOL)
75             .addRecordProtocols(RECORD_PROTOCOL);
76     // Sets handshaker options.
77     if (handshakerOptions.getRpcProtocolVersions() != null) {
78       startClientReq.setRpcVersions(handshakerOptions.getRpcProtocolVersions());
79     }
80     if (handshakerOptions instanceof AltsClientOptions) {
81       AltsClientOptions clientOptions = (AltsClientOptions) handshakerOptions;
82       if (!Strings.isNullOrEmpty(clientOptions.getTargetName())) {
83         startClientReq.setTargetName(clientOptions.getTargetName());
84       }
85       for (String serviceAccount : clientOptions.getTargetServiceAccounts()) {
86         startClientReq.addTargetIdentitiesBuilder().setServiceAccount(serviceAccount);
87       }
88     }
89     startClientReq.setMaxFrameSize(AltsTsiFrameProtector.getMaxFrameSize());
90     req.setClientStart(startClientReq);
91   }
92 
93   /** Sets the start server fields for the passed handshake request. */
setStartServerFields(HandshakerReq.Builder req, ByteBuffer inBytes)94   private void setStartServerFields(HandshakerReq.Builder req, ByteBuffer inBytes) {
95     ServerHandshakeParameters serverParameters =
96         ServerHandshakeParameters.newBuilder().addRecordProtocols(RECORD_PROTOCOL).build();
97     StartServerHandshakeReq.Builder startServerReq =
98         StartServerHandshakeReq.newBuilder()
99             .addApplicationProtocols(APPLICATION_PROTOCOL)
100             .putHandshakeParameters(HandshakeProtocol.ALTS.getNumber(), serverParameters)
101             .setInBytes(ByteString.copyFrom(inBytes.duplicate()));
102     if (handshakerOptions.getRpcProtocolVersions() != null) {
103       startServerReq.setRpcVersions(handshakerOptions.getRpcProtocolVersions());
104     }
105     startServerReq.setMaxFrameSize(AltsTsiFrameProtector.getMaxFrameSize());
106     req.setServerStart(startServerReq);
107   }
108 
109   /** Returns true if the handshake is complete. */
isFinished()110   public boolean isFinished() {
111     // If we have a HandshakeResult, we are done.
112     if (result != null) {
113       return true;
114     }
115     // If we have an error status, we are done.
116     if (status != null && status.getCode() != Status.Code.OK.value()) {
117       return true;
118     }
119     return false;
120   }
121 
122   /** Returns the handshake status. */
getStatus()123   public HandshakerStatus getStatus() {
124     return status;
125   }
126 
127   /** Returns the result data of the handshake, if the handshake is completed. */
getResult()128   public HandshakerResult getResult() {
129     return result;
130   }
131 
132   /**
133    * Returns the resulting key of the handshake, if the handshake is completed. Note that the key
134    * data returned from the handshake may be more than the key length required for the record
135    * protocol, thus we need to truncate to the right size.
136    */
getKey()137   public byte[] getKey() {
138     if (result == null) {
139       return null;
140     }
141     if (result.getKeyData().size() < KEY_LENGTH) {
142       throw new IllegalStateException("Could not get enough key data from the handshake.");
143     }
144     byte[] key = new byte[KEY_LENGTH];
145     result.getKeyData().substring(0, KEY_LENGTH).copyTo(key, 0);
146     return key;
147   }
148 
149   /**
150    * Parses a handshake response, setting the status, result, and closing the handshaker, as needed.
151    */
handleResponse(HandshakerResp resp)152   private void handleResponse(HandshakerResp resp) throws GeneralSecurityException {
153     status = resp.getStatus();
154     if (resp.hasResult()) {
155       result = resp.getResult();
156       close();
157     }
158     if (status.getCode() != Status.Code.OK.value()) {
159       String error = "Handshaker service error: " + status.getDetails();
160       logger.log(ChannelLogLevel.DEBUG, error);
161       close();
162       throw new GeneralSecurityException(error);
163     }
164   }
165 
166   /**
167    * Starts a client handshake. A GeneralSecurityException is thrown if the handshaker service is
168    * interrupted or fails. Note that isFinished() must be false before this function is called.
169    *
170    * @return the frame to give to the peer.
171    * @throws GeneralSecurityException or IllegalStateException
172    */
startClientHandshake()173   public ByteBuffer startClientHandshake() throws GeneralSecurityException {
174     Preconditions.checkState(!isFinished(), "Handshake has already finished.");
175     HandshakerReq.Builder req = HandshakerReq.newBuilder();
176     setStartClientFields(req);
177     HandshakerResp resp;
178     try {
179       logger.log(ChannelLogLevel.DEBUG, "Send ALTS handshake request to upstream");
180       resp = handshakerStub.send(req.build());
181       logger.log(ChannelLogLevel.DEBUG, "Receive ALTS handshake response from upstream");
182     } catch (IOException | InterruptedException e) {
183       throw new GeneralSecurityException(e);
184     }
185     handleResponse(resp);
186     return resp.getOutFrames().asReadOnlyByteBuffer();
187   }
188 
189   /**
190    * Starts a server handshake. A GeneralSecurityException is thrown if the handshaker service is
191    * interrupted or fails. Note that isFinished() must be false before this function is called.
192    *
193    * @param inBytes the bytes received from the peer.
194    * @return the frame to give to the peer.
195    * @throws GeneralSecurityException or IllegalStateException
196    */
startServerHandshake(ByteBuffer inBytes)197   public ByteBuffer startServerHandshake(ByteBuffer inBytes) throws GeneralSecurityException {
198     Preconditions.checkState(!isFinished(), "Handshake has already finished.");
199     HandshakerReq.Builder req = HandshakerReq.newBuilder();
200     setStartServerFields(req, inBytes);
201     HandshakerResp resp;
202     try {
203       resp = handshakerStub.send(req.build());
204     } catch (IOException | InterruptedException e) {
205       throw new GeneralSecurityException(e);
206     }
207     handleResponse(resp);
208     ((Buffer) inBytes).position(inBytes.position() + resp.getBytesConsumed());
209     return resp.getOutFrames().asReadOnlyByteBuffer();
210   }
211 
212   /**
213    * Processes the next bytes in a handshake. A GeneralSecurityException is thrown if the handshaker
214    * service is interrupted or fails. Note that isFinished() must be false before this function is
215    * called.
216    *
217    * @param inBytes the bytes received from the peer.
218    * @return the frame to give to the peer.
219    * @throws GeneralSecurityException or IllegalStateException
220    */
next(ByteBuffer inBytes)221   public ByteBuffer next(ByteBuffer inBytes) throws GeneralSecurityException {
222     Preconditions.checkState(!isFinished(), "Handshake has already finished.");
223     HandshakerReq.Builder req =
224         HandshakerReq.newBuilder()
225             .setNext(
226                 NextHandshakeMessageReq.newBuilder()
227                     .setInBytes(ByteString.copyFrom(inBytes.duplicate()))
228                     .build());
229     HandshakerResp resp;
230     try {
231       logger.log(ChannelLogLevel.DEBUG, "Send ALTS handshake request to upstream");
232       resp = handshakerStub.send(req.build());
233       logger.log(ChannelLogLevel.DEBUG, "Receive ALTS handshake response from upstream");
234     } catch (IOException | InterruptedException e) {
235       throw new GeneralSecurityException(e);
236     }
237     handleResponse(resp);
238     ((Buffer) inBytes).position(inBytes.position() + resp.getBytesConsumed());
239     return resp.getOutFrames().asReadOnlyByteBuffer();
240   }
241 
242   private boolean closed = false;
243 
244   /** Closes the connection. */
close()245   public void close() {
246     if (closed) {
247       return;
248     }
249     closed = true;
250     handshakerStub.close();
251   }
252 }
253