• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *
3  * Copyright 2018 gRPC authors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *
17  */
18 
19 #include <grpc/support/port_platform.h>
20 
21 #include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h"
22 
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 
27 #include "upb/upb.hpp"
28 
29 #include <grpc/support/alloc.h>
30 #include <grpc/support/log.h>
31 #include <grpc/support/string_util.h>
32 #include <grpc/support/sync.h>
33 #include <grpc/support/thd_id.h>
34 
35 #include "src/core/lib/gprpp/sync.h"
36 #include "src/core/lib/gprpp/thd.h"
37 #include "src/core/lib/iomgr/closure.h"
38 #include "src/core/lib/slice/slice_internal.h"
39 #include "src/core/lib/surface/channel.h"
40 #include "src/core/tsi/alts/frame_protector/alts_frame_protector.h"
41 #include "src/core/tsi/alts/handshaker/alts_handshaker_client.h"
42 #include "src/core/tsi/alts/handshaker/alts_shared_resource.h"
43 #include "src/core/tsi/alts/handshaker/alts_tsi_utils.h"
44 #include "src/core/tsi/alts/zero_copy_frame_protector/alts_zero_copy_grpc_protector.h"
45 
46 /* Main struct for ALTS TSI handshaker. */
47 struct alts_tsi_handshaker {
48   tsi_handshaker base;
49   grpc_slice target_name;
50   bool is_client;
51   bool has_sent_start_message;
52   bool has_created_handshaker_client;
53   char* handshaker_service_url;
54   grpc_pollset_set* interested_parties;
55   grpc_alts_credentials_options* options;
56   alts_handshaker_client_vtable* client_vtable_for_testing;
57   grpc_channel* channel;
58   bool use_dedicated_cq;
59   // mu synchronizes all fields below. Note these are the
60   // only fields that can be concurrently accessed (due to
61   // potential concurrency of tsi_handshaker_shutdown and
62   // tsi_handshaker_next).
63   gpr_mu mu;
64   alts_handshaker_client* client;
65   // shutdown effectively follows base.handshake_shutdown,
66   // but is synchronized by the mutex of this object.
67   bool shutdown;
68   // Maximum frame size used by frame protector.
69   size_t max_frame_size;
70 };
71 
72 /* Main struct for ALTS TSI handshaker result. */
73 typedef struct alts_tsi_handshaker_result {
74   tsi_handshaker_result base;
75   char* peer_identity;
76   char* key_data;
77   unsigned char* unused_bytes;
78   size_t unused_bytes_size;
79   grpc_slice rpc_versions;
80   bool is_client;
81   grpc_slice serialized_context;
82   // Peer's maximum frame size.
83   size_t max_frame_size;
84 } alts_tsi_handshaker_result;
85 
handshaker_result_extract_peer(const tsi_handshaker_result * self,tsi_peer * peer)86 static tsi_result handshaker_result_extract_peer(
87     const tsi_handshaker_result* self, tsi_peer* peer) {
88   if (self == nullptr || peer == nullptr) {
89     gpr_log(GPR_ERROR, "Invalid argument to handshaker_result_extract_peer()");
90     return TSI_INVALID_ARGUMENT;
91   }
92   alts_tsi_handshaker_result* result =
93       reinterpret_cast<alts_tsi_handshaker_result*>(
94           const_cast<tsi_handshaker_result*>(self));
95   GPR_ASSERT(kTsiAltsNumOfPeerProperties == 5);
96   tsi_result ok = tsi_construct_peer(kTsiAltsNumOfPeerProperties, peer);
97   int index = 0;
98   if (ok != TSI_OK) {
99     gpr_log(GPR_ERROR, "Failed to construct tsi peer");
100     return ok;
101   }
102   GPR_ASSERT(&peer->properties[index] != nullptr);
103   ok = tsi_construct_string_peer_property_from_cstring(
104       TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_ALTS_CERTIFICATE_TYPE,
105       &peer->properties[index]);
106   if (ok != TSI_OK) {
107     tsi_peer_destruct(peer);
108     gpr_log(GPR_ERROR, "Failed to set tsi peer property");
109     return ok;
110   }
111   index++;
112   GPR_ASSERT(&peer->properties[index] != nullptr);
113   ok = tsi_construct_string_peer_property_from_cstring(
114       TSI_ALTS_SERVICE_ACCOUNT_PEER_PROPERTY, result->peer_identity,
115       &peer->properties[index]);
116   if (ok != TSI_OK) {
117     tsi_peer_destruct(peer);
118     gpr_log(GPR_ERROR, "Failed to set tsi peer property");
119   }
120   index++;
121   GPR_ASSERT(&peer->properties[index] != nullptr);
122   ok = tsi_construct_string_peer_property(
123       TSI_ALTS_RPC_VERSIONS,
124       reinterpret_cast<char*>(GRPC_SLICE_START_PTR(result->rpc_versions)),
125       GRPC_SLICE_LENGTH(result->rpc_versions), &peer->properties[index]);
126   if (ok != TSI_OK) {
127     tsi_peer_destruct(peer);
128     gpr_log(GPR_ERROR, "Failed to set tsi peer property");
129   }
130   index++;
131   GPR_ASSERT(&peer->properties[index] != nullptr);
132   ok = tsi_construct_string_peer_property(
133       TSI_ALTS_CONTEXT,
134       reinterpret_cast<char*>(GRPC_SLICE_START_PTR(result->serialized_context)),
135       GRPC_SLICE_LENGTH(result->serialized_context), &peer->properties[index]);
136   if (ok != TSI_OK) {
137     tsi_peer_destruct(peer);
138     gpr_log(GPR_ERROR, "Failed to set tsi peer property");
139   }
140   index++;
141   GPR_ASSERT(&peer->properties[index] != nullptr);
142   ok = tsi_construct_string_peer_property_from_cstring(
143       TSI_SECURITY_LEVEL_PEER_PROPERTY,
144       tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY),
145       &peer->properties[index]);
146   if (ok != TSI_OK) {
147     tsi_peer_destruct(peer);
148     gpr_log(GPR_ERROR, "Failed to set tsi peer property");
149   }
150   GPR_ASSERT(++index == kTsiAltsNumOfPeerProperties);
151   return ok;
152 }
153 
handshaker_result_create_zero_copy_grpc_protector(const tsi_handshaker_result * self,size_t * max_output_protected_frame_size,tsi_zero_copy_grpc_protector ** protector)154 static tsi_result handshaker_result_create_zero_copy_grpc_protector(
155     const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
156     tsi_zero_copy_grpc_protector** protector) {
157   if (self == nullptr || protector == nullptr) {
158     gpr_log(GPR_ERROR,
159             "Invalid arguments to create_zero_copy_grpc_protector()");
160     return TSI_INVALID_ARGUMENT;
161   }
162   alts_tsi_handshaker_result* result =
163       reinterpret_cast<alts_tsi_handshaker_result*>(
164           const_cast<tsi_handshaker_result*>(self));
165 
166   // In case the peer does not send max frame size (e.g. peer is gRPC Go or
167   // peer uses an old binary), the negotiated frame size is set to
168   // kTsiAltsMinFrameSize (ignoring max_output_protected_frame_size value if
169   // present). Otherwise, it is based on peer and user specified max frame
170   // size (if present).
171   size_t max_frame_size = kTsiAltsMinFrameSize;
172   if (result->max_frame_size) {
173     size_t peer_max_frame_size = result->max_frame_size;
174     max_frame_size = std::min<size_t>(peer_max_frame_size,
175                                       max_output_protected_frame_size == nullptr
176                                           ? kTsiAltsMaxFrameSize
177                                           : *max_output_protected_frame_size);
178     max_frame_size = std::max<size_t>(max_frame_size, kTsiAltsMinFrameSize);
179   }
180   max_output_protected_frame_size = &max_frame_size;
181   gpr_log(GPR_DEBUG,
182           "After Frame Size Negotiation, maximum frame size used by frame "
183           "protector equals %zu",
184           *max_output_protected_frame_size);
185   tsi_result ok = alts_zero_copy_grpc_protector_create(
186       reinterpret_cast<const uint8_t*>(result->key_data),
187       kAltsAes128GcmRekeyKeyLength, /*is_rekey=*/true, result->is_client,
188       /*is_integrity_only=*/false, /*enable_extra_copy=*/false,
189       max_output_protected_frame_size, protector);
190   if (ok != TSI_OK) {
191     gpr_log(GPR_ERROR, "Failed to create zero-copy grpc protector");
192   }
193   return ok;
194 }
195 
handshaker_result_create_frame_protector(const tsi_handshaker_result * self,size_t * max_output_protected_frame_size,tsi_frame_protector ** protector)196 static tsi_result handshaker_result_create_frame_protector(
197     const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
198     tsi_frame_protector** protector) {
199   if (self == nullptr || protector == nullptr) {
200     gpr_log(GPR_ERROR,
201             "Invalid arguments to handshaker_result_create_frame_protector()");
202     return TSI_INVALID_ARGUMENT;
203   }
204   alts_tsi_handshaker_result* result =
205       reinterpret_cast<alts_tsi_handshaker_result*>(
206           const_cast<tsi_handshaker_result*>(self));
207   tsi_result ok = alts_create_frame_protector(
208       reinterpret_cast<const uint8_t*>(result->key_data),
209       kAltsAes128GcmRekeyKeyLength, result->is_client, /*is_rekey=*/true,
210       max_output_protected_frame_size, protector);
211   if (ok != TSI_OK) {
212     gpr_log(GPR_ERROR, "Failed to create frame protector");
213   }
214   return ok;
215 }
216 
handshaker_result_get_unused_bytes(const tsi_handshaker_result * self,const unsigned char ** bytes,size_t * bytes_size)217 static tsi_result handshaker_result_get_unused_bytes(
218     const tsi_handshaker_result* self, const unsigned char** bytes,
219     size_t* bytes_size) {
220   if (self == nullptr || bytes == nullptr || bytes_size == nullptr) {
221     gpr_log(GPR_ERROR,
222             "Invalid arguments to handshaker_result_get_unused_bytes()");
223     return TSI_INVALID_ARGUMENT;
224   }
225   alts_tsi_handshaker_result* result =
226       reinterpret_cast<alts_tsi_handshaker_result*>(
227           const_cast<tsi_handshaker_result*>(self));
228   *bytes = result->unused_bytes;
229   *bytes_size = result->unused_bytes_size;
230   return TSI_OK;
231 }
232 
handshaker_result_destroy(tsi_handshaker_result * self)233 static void handshaker_result_destroy(tsi_handshaker_result* self) {
234   if (self == nullptr) {
235     return;
236   }
237   alts_tsi_handshaker_result* result =
238       reinterpret_cast<alts_tsi_handshaker_result*>(
239           const_cast<tsi_handshaker_result*>(self));
240   gpr_free(result->peer_identity);
241   gpr_free(result->key_data);
242   gpr_free(result->unused_bytes);
243   grpc_slice_unref_internal(result->rpc_versions);
244   grpc_slice_unref_internal(result->serialized_context);
245   gpr_free(result);
246 }
247 
248 static const tsi_handshaker_result_vtable result_vtable = {
249     handshaker_result_extract_peer,
250     handshaker_result_create_zero_copy_grpc_protector,
251     handshaker_result_create_frame_protector,
252     handshaker_result_get_unused_bytes, handshaker_result_destroy};
253 
alts_tsi_handshaker_result_create(grpc_gcp_HandshakerResp * resp,bool is_client,tsi_handshaker_result ** result)254 tsi_result alts_tsi_handshaker_result_create(grpc_gcp_HandshakerResp* resp,
255                                              bool is_client,
256                                              tsi_handshaker_result** result) {
257   if (result == nullptr || resp == nullptr) {
258     gpr_log(GPR_ERROR, "Invalid arguments to create_handshaker_result()");
259     return TSI_INVALID_ARGUMENT;
260   }
261   const grpc_gcp_HandshakerResult* hresult =
262       grpc_gcp_HandshakerResp_result(resp);
263   const grpc_gcp_Identity* identity =
264       grpc_gcp_HandshakerResult_peer_identity(hresult);
265   if (identity == nullptr) {
266     gpr_log(GPR_ERROR, "Invalid identity");
267     return TSI_FAILED_PRECONDITION;
268   }
269   upb_strview peer_service_account =
270       grpc_gcp_Identity_service_account(identity);
271   if (peer_service_account.size == 0) {
272     gpr_log(GPR_ERROR, "Invalid peer service account");
273     return TSI_FAILED_PRECONDITION;
274   }
275   upb_strview key_data = grpc_gcp_HandshakerResult_key_data(hresult);
276   if (key_data.size < kAltsAes128GcmRekeyKeyLength) {
277     gpr_log(GPR_ERROR, "Bad key length");
278     return TSI_FAILED_PRECONDITION;
279   }
280   const grpc_gcp_RpcProtocolVersions* peer_rpc_version =
281       grpc_gcp_HandshakerResult_peer_rpc_versions(hresult);
282   if (peer_rpc_version == nullptr) {
283     gpr_log(GPR_ERROR, "Peer does not set RPC protocol versions.");
284     return TSI_FAILED_PRECONDITION;
285   }
286   upb_strview application_protocol =
287       grpc_gcp_HandshakerResult_application_protocol(hresult);
288   if (application_protocol.size == 0) {
289     gpr_log(GPR_ERROR, "Invalid application protocol");
290     return TSI_FAILED_PRECONDITION;
291   }
292   upb_strview record_protocol =
293       grpc_gcp_HandshakerResult_record_protocol(hresult);
294   if (record_protocol.size == 0) {
295     gpr_log(GPR_ERROR, "Invalid record protocol");
296     return TSI_FAILED_PRECONDITION;
297   }
298   const grpc_gcp_Identity* local_identity =
299       grpc_gcp_HandshakerResult_local_identity(hresult);
300   if (local_identity == nullptr) {
301     gpr_log(GPR_ERROR, "Invalid local identity");
302     return TSI_FAILED_PRECONDITION;
303   }
304   upb_strview local_service_account =
305       grpc_gcp_Identity_service_account(local_identity);
306   // We don't check if local service account is empty here
307   // because local identity could be empty in certain situations.
308   alts_tsi_handshaker_result* sresult =
309       static_cast<alts_tsi_handshaker_result*>(gpr_zalloc(sizeof(*sresult)));
310   sresult->key_data =
311       static_cast<char*>(gpr_zalloc(kAltsAes128GcmRekeyKeyLength));
312   memcpy(sresult->key_data, key_data.data, kAltsAes128GcmRekeyKeyLength);
313   sresult->peer_identity =
314       static_cast<char*>(gpr_zalloc(peer_service_account.size + 1));
315   memcpy(sresult->peer_identity, peer_service_account.data,
316          peer_service_account.size);
317   sresult->max_frame_size = grpc_gcp_HandshakerResult_max_frame_size(hresult);
318   upb::Arena rpc_versions_arena;
319   bool serialized = grpc_gcp_rpc_protocol_versions_encode(
320       peer_rpc_version, rpc_versions_arena.ptr(), &sresult->rpc_versions);
321   if (!serialized) {
322     gpr_log(GPR_ERROR, "Failed to serialize peer's RPC protocol versions.");
323     return TSI_FAILED_PRECONDITION;
324   }
325   upb::Arena context_arena;
326   grpc_gcp_AltsContext* context = grpc_gcp_AltsContext_new(context_arena.ptr());
327   grpc_gcp_AltsContext_set_application_protocol(context, application_protocol);
328   grpc_gcp_AltsContext_set_record_protocol(context, record_protocol);
329   // ALTS currently only supports the security level of 2,
330   // which is "grpc_gcp_INTEGRITY_AND_PRIVACY".
331   grpc_gcp_AltsContext_set_security_level(context, 2);
332   grpc_gcp_AltsContext_set_peer_service_account(context, peer_service_account);
333   grpc_gcp_AltsContext_set_local_service_account(context,
334                                                  local_service_account);
335   grpc_gcp_AltsContext_set_peer_rpc_versions(
336       context, const_cast<grpc_gcp_RpcProtocolVersions*>(peer_rpc_version));
337   grpc_gcp_Identity* peer_identity = const_cast<grpc_gcp_Identity*>(identity);
338   if (peer_identity == nullptr) {
339     gpr_log(GPR_ERROR, "Null peer identity in ALTS context.");
340     return TSI_FAILED_PRECONDITION;
341   }
342   if (grpc_gcp_Identity_has_attributes(identity)) {
343     size_t iter = UPB_MAP_BEGIN;
344     grpc_gcp_Identity_AttributesEntry* peer_attributes_entry =
345         grpc_gcp_Identity_attributes_nextmutable(peer_identity, &iter);
346     while (peer_attributes_entry != nullptr) {
347       upb_strview key = grpc_gcp_Identity_AttributesEntry_key(
348           const_cast<grpc_gcp_Identity_AttributesEntry*>(
349               peer_attributes_entry));
350       upb_strview val = grpc_gcp_Identity_AttributesEntry_value(
351           const_cast<grpc_gcp_Identity_AttributesEntry*>(
352               peer_attributes_entry));
353       grpc_gcp_AltsContext_peer_attributes_set(context, key, val,
354                                                context_arena.ptr());
355       peer_attributes_entry =
356           grpc_gcp_Identity_attributes_nextmutable(peer_identity, &iter);
357     }
358   }
359   size_t serialized_ctx_length;
360   char* serialized_ctx = grpc_gcp_AltsContext_serialize(
361       context, context_arena.ptr(), &serialized_ctx_length);
362   if (serialized_ctx == nullptr) {
363     gpr_log(GPR_ERROR, "Failed to serialize peer's ALTS context.");
364     return TSI_FAILED_PRECONDITION;
365   }
366   sresult->serialized_context =
367       grpc_slice_from_copied_buffer(serialized_ctx, serialized_ctx_length);
368   sresult->is_client = is_client;
369   sresult->base.vtable = &result_vtable;
370   *result = &sresult->base;
371   return TSI_OK;
372 }
373 
374 /* gRPC provided callback used when gRPC thread model is applied. */
on_handshaker_service_resp_recv(void * arg,grpc_error * error)375 static void on_handshaker_service_resp_recv(void* arg, grpc_error* error) {
376   alts_handshaker_client* client = static_cast<alts_handshaker_client*>(arg);
377   if (client == nullptr) {
378     gpr_log(GPR_ERROR, "ALTS handshaker client is nullptr");
379     return;
380   }
381   bool success = true;
382   if (error != GRPC_ERROR_NONE) {
383     gpr_log(GPR_ERROR,
384             "ALTS handshaker on_handshaker_service_resp_recv error: %s",
385             grpc_error_string(error));
386     success = false;
387   }
388   alts_handshaker_client_handle_response(client, success);
389 }
390 
391 /* gRPC provided callback used when dedicatd CQ and thread are used.
392  * It serves to safely bring the control back to application. */
on_handshaker_service_resp_recv_dedicated(void * arg,grpc_error *)393 static void on_handshaker_service_resp_recv_dedicated(void* arg,
394                                                       grpc_error* /*error*/) {
395   alts_shared_resource_dedicated* resource =
396       grpc_alts_get_shared_resource_dedicated();
397   grpc_cq_end_op(
398       resource->cq, arg, GRPC_ERROR_NONE,
399       [](void* /*done_arg*/, grpc_cq_completion* /*storage*/) {}, nullptr,
400       &resource->storage);
401 }
402 
403 /* Returns TSI_OK if and only if no error is encountered. */
alts_tsi_handshaker_continue_handshaker_next(alts_tsi_handshaker * handshaker,const unsigned char * received_bytes,size_t received_bytes_size,tsi_handshaker_on_next_done_cb cb,void * user_data)404 static tsi_result alts_tsi_handshaker_continue_handshaker_next(
405     alts_tsi_handshaker* handshaker, const unsigned char* received_bytes,
406     size_t received_bytes_size, tsi_handshaker_on_next_done_cb cb,
407     void* user_data) {
408   if (!handshaker->has_created_handshaker_client) {
409     if (handshaker->channel == nullptr) {
410       grpc_alts_shared_resource_dedicated_start(
411           handshaker->handshaker_service_url);
412       handshaker->interested_parties =
413           grpc_alts_get_shared_resource_dedicated()->interested_parties;
414       GPR_ASSERT(handshaker->interested_parties != nullptr);
415     }
416     grpc_iomgr_cb_func grpc_cb = handshaker->channel == nullptr
417                                      ? on_handshaker_service_resp_recv_dedicated
418                                      : on_handshaker_service_resp_recv;
419     grpc_channel* channel =
420         handshaker->channel == nullptr
421             ? grpc_alts_get_shared_resource_dedicated()->channel
422             : handshaker->channel;
423     alts_handshaker_client* client = alts_grpc_handshaker_client_create(
424         handshaker, channel, handshaker->handshaker_service_url,
425         handshaker->interested_parties, handshaker->options,
426         handshaker->target_name, grpc_cb, cb, user_data,
427         handshaker->client_vtable_for_testing, handshaker->is_client,
428         handshaker->max_frame_size);
429     if (client == nullptr) {
430       gpr_log(GPR_ERROR, "Failed to create ALTS handshaker client");
431       return TSI_FAILED_PRECONDITION;
432     }
433     {
434       grpc_core::MutexLock lock(&handshaker->mu);
435       GPR_ASSERT(handshaker->client == nullptr);
436       handshaker->client = client;
437       if (handshaker->shutdown) {
438         gpr_log(GPR_ERROR, "TSI handshake shutdown");
439         return TSI_HANDSHAKE_SHUTDOWN;
440       }
441     }
442     handshaker->has_created_handshaker_client = true;
443   }
444   if (handshaker->channel == nullptr &&
445       handshaker->client_vtable_for_testing == nullptr) {
446     GPR_ASSERT(grpc_cq_begin_op(grpc_alts_get_shared_resource_dedicated()->cq,
447                                 handshaker->client));
448   }
449   grpc_slice slice = (received_bytes == nullptr || received_bytes_size == 0)
450                          ? grpc_empty_slice()
451                          : grpc_slice_from_copied_buffer(
452                                reinterpret_cast<const char*>(received_bytes),
453                                received_bytes_size);
454   tsi_result ok = TSI_OK;
455   if (!handshaker->has_sent_start_message) {
456     handshaker->has_sent_start_message = true;
457     ok = handshaker->is_client
458              ? alts_handshaker_client_start_client(handshaker->client)
459              : alts_handshaker_client_start_server(handshaker->client, &slice);
460     // It's unsafe for the current thread to access any state in handshaker
461     // at this point, since alts_handshaker_client_start_client/server
462     // have potentially just started an op batch on the handshake call.
463     // The completion callback for that batch is unsynchronized and so
464     // can invoke the TSI next API callback from any thread, at which point
465     // there is nothing taking ownership of this handshaker to prevent it
466     // from being destroyed.
467   } else {
468     ok = alts_handshaker_client_next(handshaker->client, &slice);
469   }
470   grpc_slice_unref_internal(slice);
471   return ok;
472 }
473 
474 struct alts_tsi_handshaker_continue_handshaker_next_args {
475   alts_tsi_handshaker* handshaker;
476   std::unique_ptr<unsigned char> received_bytes;
477   size_t received_bytes_size;
478   tsi_handshaker_on_next_done_cb cb;
479   void* user_data;
480   grpc_closure closure;
481 };
482 
alts_tsi_handshaker_create_channel(void * arg,grpc_error *)483 static void alts_tsi_handshaker_create_channel(void* arg,
484                                                grpc_error* /* unused_error */) {
485   alts_tsi_handshaker_continue_handshaker_next_args* next_args =
486       static_cast<alts_tsi_handshaker_continue_handshaker_next_args*>(arg);
487   alts_tsi_handshaker* handshaker = next_args->handshaker;
488   GPR_ASSERT(handshaker->channel == nullptr);
489   handshaker->channel = grpc_insecure_channel_create(
490       next_args->handshaker->handshaker_service_url, nullptr, nullptr);
491   tsi_result continue_next_result =
492       alts_tsi_handshaker_continue_handshaker_next(
493           handshaker, next_args->received_bytes.get(),
494           next_args->received_bytes_size, next_args->cb, next_args->user_data);
495   if (continue_next_result != TSI_OK) {
496     next_args->cb(continue_next_result, next_args->user_data, nullptr, 0,
497                   nullptr);
498   }
499   delete next_args;
500 }
501 
handshaker_next(tsi_handshaker * self,const unsigned char * received_bytes,size_t received_bytes_size,const unsigned char **,size_t *,tsi_handshaker_result **,tsi_handshaker_on_next_done_cb cb,void * user_data)502 static tsi_result handshaker_next(
503     tsi_handshaker* self, const unsigned char* received_bytes,
504     size_t received_bytes_size, const unsigned char** /*bytes_to_send*/,
505     size_t* /*bytes_to_send_size*/, tsi_handshaker_result** /*result*/,
506     tsi_handshaker_on_next_done_cb cb, void* user_data) {
507   if (self == nullptr || cb == nullptr) {
508     gpr_log(GPR_ERROR, "Invalid arguments to handshaker_next()");
509     return TSI_INVALID_ARGUMENT;
510   }
511   alts_tsi_handshaker* handshaker =
512       reinterpret_cast<alts_tsi_handshaker*>(self);
513   {
514     grpc_core::MutexLock lock(&handshaker->mu);
515     if (handshaker->shutdown) {
516       gpr_log(GPR_ERROR, "TSI handshake shutdown");
517       return TSI_HANDSHAKE_SHUTDOWN;
518     }
519   }
520   if (handshaker->channel == nullptr && !handshaker->use_dedicated_cq) {
521     alts_tsi_handshaker_continue_handshaker_next_args* args =
522         new alts_tsi_handshaker_continue_handshaker_next_args();
523     args->handshaker = handshaker;
524     args->received_bytes = nullptr;
525     args->received_bytes_size = received_bytes_size;
526     if (received_bytes_size > 0) {
527       args->received_bytes = std::unique_ptr<unsigned char>(
528           static_cast<unsigned char*>(gpr_zalloc(received_bytes_size)));
529       memcpy(args->received_bytes.get(), received_bytes, received_bytes_size);
530     }
531     args->cb = cb;
532     args->user_data = user_data;
533     GRPC_CLOSURE_INIT(&args->closure, alts_tsi_handshaker_create_channel, args,
534                       grpc_schedule_on_exec_ctx);
535     // We continue this handshaker_next call at the bottom of the ExecCtx just
536     // so that we can invoke grpc_channel_create at the bottom of the call
537     // stack. Doing so avoids potential lock cycles between g_init_mu and other
538     // mutexes within core that might be held on the current call stack
539     // (note that g_init_mu gets acquired during channel creation).
540     grpc_core::ExecCtx::Run(DEBUG_LOCATION, &args->closure, GRPC_ERROR_NONE);
541   } else {
542     tsi_result ok = alts_tsi_handshaker_continue_handshaker_next(
543         handshaker, received_bytes, received_bytes_size, cb, user_data);
544     if (ok != TSI_OK) {
545       gpr_log(GPR_ERROR, "Failed to schedule ALTS handshaker requests");
546       return ok;
547     }
548   }
549   return TSI_ASYNC;
550 }
551 
552 /*
553  * This API will be invoked by a non-gRPC application, and an ExecCtx needs
554  * to be explicitly created in order to invoke ALTS handshaker client API's
555  * that assumes the caller is inside gRPC core.
556  */
handshaker_next_dedicated(tsi_handshaker * self,const unsigned char * received_bytes,size_t received_bytes_size,const unsigned char ** bytes_to_send,size_t * bytes_to_send_size,tsi_handshaker_result ** result,tsi_handshaker_on_next_done_cb cb,void * user_data)557 static tsi_result handshaker_next_dedicated(
558     tsi_handshaker* self, const unsigned char* received_bytes,
559     size_t received_bytes_size, const unsigned char** bytes_to_send,
560     size_t* bytes_to_send_size, tsi_handshaker_result** result,
561     tsi_handshaker_on_next_done_cb cb, void* user_data) {
562   grpc_core::ExecCtx exec_ctx;
563   return handshaker_next(self, received_bytes, received_bytes_size,
564                          bytes_to_send, bytes_to_send_size, result, cb,
565                          user_data);
566 }
567 
handshaker_shutdown(tsi_handshaker * self)568 static void handshaker_shutdown(tsi_handshaker* self) {
569   GPR_ASSERT(self != nullptr);
570   alts_tsi_handshaker* handshaker =
571       reinterpret_cast<alts_tsi_handshaker*>(self);
572   grpc_core::MutexLock lock(&handshaker->mu);
573   if (handshaker->shutdown) {
574     return;
575   }
576   if (handshaker->client != nullptr) {
577     alts_handshaker_client_shutdown(handshaker->client);
578   }
579   handshaker->shutdown = true;
580 }
581 
handshaker_destroy(tsi_handshaker * self)582 static void handshaker_destroy(tsi_handshaker* self) {
583   if (self == nullptr) {
584     return;
585   }
586   alts_tsi_handshaker* handshaker =
587       reinterpret_cast<alts_tsi_handshaker*>(self);
588   alts_handshaker_client_destroy(handshaker->client);
589   grpc_slice_unref_internal(handshaker->target_name);
590   grpc_alts_credentials_options_destroy(handshaker->options);
591   if (handshaker->channel != nullptr) {
592     grpc_channel_destroy_internal(handshaker->channel);
593   }
594   gpr_free(handshaker->handshaker_service_url);
595   gpr_mu_destroy(&handshaker->mu);
596   gpr_free(handshaker);
597 }
598 
599 static const tsi_handshaker_vtable handshaker_vtable = {
600     nullptr,         nullptr,
601     nullptr,         nullptr,
602     nullptr,         handshaker_destroy,
603     handshaker_next, handshaker_shutdown};
604 
605 static const tsi_handshaker_vtable handshaker_vtable_dedicated = {
606     nullptr,
607     nullptr,
608     nullptr,
609     nullptr,
610     nullptr,
611     handshaker_destroy,
612     handshaker_next_dedicated,
613     handshaker_shutdown};
614 
alts_tsi_handshaker_has_shutdown(alts_tsi_handshaker * handshaker)615 bool alts_tsi_handshaker_has_shutdown(alts_tsi_handshaker* handshaker) {
616   GPR_ASSERT(handshaker != nullptr);
617   grpc_core::MutexLock lock(&handshaker->mu);
618   return handshaker->shutdown;
619 }
620 
alts_tsi_handshaker_create(const grpc_alts_credentials_options * options,const char * target_name,const char * handshaker_service_url,bool is_client,grpc_pollset_set * interested_parties,tsi_handshaker ** self,size_t user_specified_max_frame_size)621 tsi_result alts_tsi_handshaker_create(
622     const grpc_alts_credentials_options* options, const char* target_name,
623     const char* handshaker_service_url, bool is_client,
624     grpc_pollset_set* interested_parties, tsi_handshaker** self,
625     size_t user_specified_max_frame_size) {
626   if (handshaker_service_url == nullptr || self == nullptr ||
627       options == nullptr || (is_client && target_name == nullptr)) {
628     gpr_log(GPR_ERROR, "Invalid arguments to alts_tsi_handshaker_create()");
629     return TSI_INVALID_ARGUMENT;
630   }
631   alts_tsi_handshaker* handshaker =
632       static_cast<alts_tsi_handshaker*>(gpr_zalloc(sizeof(*handshaker)));
633   gpr_mu_init(&handshaker->mu);
634   handshaker->use_dedicated_cq = interested_parties == nullptr;
635   handshaker->client = nullptr;
636   handshaker->is_client = is_client;
637   handshaker->has_sent_start_message = false;
638   handshaker->target_name = target_name == nullptr
639                                 ? grpc_empty_slice()
640                                 : grpc_slice_from_static_string(target_name);
641   handshaker->interested_parties = interested_parties;
642   handshaker->has_created_handshaker_client = false;
643   handshaker->handshaker_service_url = gpr_strdup(handshaker_service_url);
644   handshaker->options = grpc_alts_credentials_options_copy(options);
645   handshaker->max_frame_size = user_specified_max_frame_size != 0
646                                    ? user_specified_max_frame_size
647                                    : kTsiAltsMaxFrameSize;
648   handshaker->base.vtable = handshaker->use_dedicated_cq
649                                 ? &handshaker_vtable_dedicated
650                                 : &handshaker_vtable;
651   *self = &handshaker->base;
652   return TSI_OK;
653 }
654 
alts_tsi_handshaker_result_set_unused_bytes(tsi_handshaker_result * result,grpc_slice * recv_bytes,size_t bytes_consumed)655 void alts_tsi_handshaker_result_set_unused_bytes(tsi_handshaker_result* result,
656                                                  grpc_slice* recv_bytes,
657                                                  size_t bytes_consumed) {
658   GPR_ASSERT(recv_bytes != nullptr && result != nullptr);
659   if (GRPC_SLICE_LENGTH(*recv_bytes) == bytes_consumed) {
660     return;
661   }
662   alts_tsi_handshaker_result* sresult =
663       reinterpret_cast<alts_tsi_handshaker_result*>(result);
664   sresult->unused_bytes_size = GRPC_SLICE_LENGTH(*recv_bytes) - bytes_consumed;
665   sresult->unused_bytes =
666       static_cast<unsigned char*>(gpr_zalloc(sresult->unused_bytes_size));
667   memcpy(sresult->unused_bytes,
668          GRPC_SLICE_START_PTR(*recv_bytes) + bytes_consumed,
669          sresult->unused_bytes_size);
670 }
671 
672 namespace grpc_core {
673 namespace internal {
674 
alts_tsi_handshaker_get_has_sent_start_message_for_testing(alts_tsi_handshaker * handshaker)675 bool alts_tsi_handshaker_get_has_sent_start_message_for_testing(
676     alts_tsi_handshaker* handshaker) {
677   GPR_ASSERT(handshaker != nullptr);
678   return handshaker->has_sent_start_message;
679 }
680 
alts_tsi_handshaker_set_client_vtable_for_testing(alts_tsi_handshaker * handshaker,alts_handshaker_client_vtable * vtable)681 void alts_tsi_handshaker_set_client_vtable_for_testing(
682     alts_tsi_handshaker* handshaker, alts_handshaker_client_vtable* vtable) {
683   GPR_ASSERT(handshaker != nullptr);
684   handshaker->client_vtable_for_testing = vtable;
685 }
686 
alts_tsi_handshaker_get_is_client_for_testing(alts_tsi_handshaker * handshaker)687 bool alts_tsi_handshaker_get_is_client_for_testing(
688     alts_tsi_handshaker* handshaker) {
689   GPR_ASSERT(handshaker != nullptr);
690   return handshaker->is_client;
691 }
692 
alts_tsi_handshaker_get_client_for_testing(alts_tsi_handshaker * handshaker)693 alts_handshaker_client* alts_tsi_handshaker_get_client_for_testing(
694     alts_tsi_handshaker* handshaker) {
695   return handshaker->client;
696 }
697 
698 }  // namespace internal
699 }  // namespace grpc_core
700