1 /*
2 *
3 * Copyright 2018 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19 #include <grpc/support/port_platform.h>
20
21 #include <list>
22
23 #include "src/core/tsi/alts/handshaker/alts_handshaker_client.h"
24
25 #include "upb/upb.hpp"
26
27 #include <grpc/byte_buffer.h>
28 #include <grpc/support/alloc.h>
29 #include <grpc/support/log.h>
30
31 #include "src/core/lib/gprpp/sync.h"
32 #include "src/core/lib/slice/slice_internal.h"
33 #include "src/core/lib/surface/call.h"
34 #include "src/core/lib/surface/channel.h"
35 #include "src/core/tsi/alts/handshaker/alts_shared_resource.h"
36 #include "src/core/tsi/alts/handshaker/alts_tsi_handshaker_private.h"
37 #include "src/core/tsi/alts/handshaker/alts_tsi_utils.h"
38
/* Initial capacity, in bytes, of alts_grpc_handshaker_client::buffer, which
 * holds the out_frames received from the handshaker service. The buffer is
 * grown by doubling in alts_handshaker_client_handle_response() when a
 * response does not fit. */
#define TSI_ALTS_INITIAL_BUFFER_SIZE 256

/* Capacity of the grpc_op array built in continue_make_grpc_call(). The
 * largest batch it issues (the first message of a call) holds four ops:
 * SEND_INITIAL_METADATA, RECV_INITIAL_METADATA, SEND_MESSAGE, RECV_MESSAGE. */
const int kHandshakerClientOpNum = 4;
42
/* Base "class" of a handshaker client: it holds only the vtable through which
 * every operation is dispatched. Implementations embed it as their first
 * member (see alts_grpc_handshaker_client::base), so that a pointer to this
 * struct can be reinterpret_cast to the implementation type. */
struct alts_handshaker_client {
  const alts_handshaker_client_vtable* vtable;
};
46
/* Arguments saved for a possibly deferred invocation of the TSI next callback
 * (client->cb). Allocated in handle_response_done(), and freed by
 * maybe_complete_tsi_next() after the callback has run. */
struct recv_message_result {
  /* TSI status reported to the callback. */
  tsi_result status;
  /* Bytes to send to the peer. Not owned: either nullptr or a pointer into
   * client->buffer. */
  const unsigned char* bytes_to_send;
  size_t bytes_to_send_size;
  /* Handshaker result. Non-null only when the service response carried a
   * result. */
  tsi_handshaker_result* result;
};
53
/* gRPC-backed implementation of alts_handshaker_client, which exchanges
 * messages with the ALTS handshaker service over a single grpc_call.
 * Allocated with new in alts_grpc_handshaker_client_create(), and deleted in
 * alts_grpc_handshaker_client_unref() when the last ref is dropped. */
typedef struct alts_grpc_handshaker_client {
  /* Must remain the first member: this file converts between
   * alts_handshaker_client* and alts_grpc_handshaker_client* with
   * reinterpret_cast. */
  alts_handshaker_client base;
  /* One ref is held by the entity that created this handshaker_client, and
   * another ref is held by the pending RECEIVE_STATUS_ON_CLIENT op. */
  gpr_refcount refs;
  /* TSI handshaker served by this client. May be nullptr; this is checked in
   * alts_handshaker_client_handle_response(). */
  alts_tsi_handshaker* handshaker;
  /* Call to the handshaker service. It is nullptr when the client was created
   * with ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING. */
  grpc_call* call;
  /* Function used to start gRPC batches on |call|. In production it is
   * grpc_call_start_batch_and_execute. Tests may replace it with a custom
   * function that validates the data sent to the handshaker service. */
  alts_grpc_caller grpc_caller;
  /* Closure scheduled when a response from the handshaker service has been
   * received. Initialized with the grpc_cb passed to
   * alts_grpc_handshaker_client_create(). */
  grpc_closure on_handshaker_service_resp_recv;
  /* Serialized request to send to the handshaker service, and the raw
   * response received from it. */
  grpc_byte_buffer* send_buffer = nullptr;
  grpc_byte_buffer* recv_buffer = nullptr;
  /* Checked in alts_handshaker_client_handle_response(). No gRPC op issued in
   * this file writes it.
   * NOTE(review): presumably only set by test helpers — confirm. */
  grpc_status_code status = GRPC_STATUS_OK;
  /* Initial metadata to be received from the handshaker service. */
  grpc_metadata_array recv_initial_metadata;
  /* TSI next callback and its argument, invoked with the outcome of each
   * response from the handshaker service. */
  tsi_handshaker_on_next_done_cb cb;
  void* user_data;
  /* Owned copy of the ALTS credential options passed in by the caller. */
  grpc_alts_credentials_options* options;
  /* Owned copy of the target name. It is sent to the handshaker service (in
   * StartClientHandshakeReq) for the server authorization check. */
  grpc_slice target_name;
  /* true if this client serves the client side of the handshake, false for
   * the server side. */
  bool is_client;
  /* Latest bytes received from the peer (a ref is held). Used to extract the
   * bytes the handshaker service did not consume. */
  grpc_slice recv_bytes;
  /* Heap buffer holding the data to be sent to the peer. It grows on demand,
   * and buffer_size is its current capacity in bytes. */
  unsigned char* buffer;
  size_t buffer_size;
  /** Callback for receiving the handshake call's status. */
  grpc_closure on_status_received;
  /** gRPC status code of the handshake call. */
  grpc_status_code handshake_status_code = GRPC_STATUS_OK;
  /** gRPC status details of the handshake call. */
  grpc_slice handshake_status_details;
  /* mu synchronizes receive_status_finished and pending_recv_message_result
   * (the two fields directly below), including their internal fields. */
  grpc_core::Mutex mu;
  /* Indicates whether the handshaker call's RECV_STATUS_ON_CLIENT op is
   * done. */
  bool receive_status_finished = false;
  /* If non-null, contains the arguments needed to complete a TSI next
   * callback. */
  recv_message_result* pending_recv_message_result = nullptr;
  /* Maximum frame size used by the frame protector. Set once at creation and
   * read without holding mu. */
  size_t max_frame_size;
} alts_grpc_handshaker_client;
110
handshaker_client_send_buffer_destroy(alts_grpc_handshaker_client * client)111 static void handshaker_client_send_buffer_destroy(
112 alts_grpc_handshaker_client* client) {
113 GPR_ASSERT(client != nullptr);
114 grpc_byte_buffer_destroy(client->send_buffer);
115 client->send_buffer = nullptr;
116 }
117
is_handshake_finished_properly(grpc_gcp_HandshakerResp * resp)118 static bool is_handshake_finished_properly(grpc_gcp_HandshakerResp* resp) {
119 GPR_ASSERT(resp != nullptr);
120 if (grpc_gcp_HandshakerResp_result(resp)) {
121 return true;
122 }
123 return false;
124 }
125
alts_grpc_handshaker_client_unref(alts_grpc_handshaker_client * client)126 static void alts_grpc_handshaker_client_unref(
127 alts_grpc_handshaker_client* client) {
128 if (gpr_unref(&client->refs)) {
129 if (client->base.vtable != nullptr &&
130 client->base.vtable->destruct != nullptr) {
131 client->base.vtable->destruct(&client->base);
132 }
133 grpc_byte_buffer_destroy(client->send_buffer);
134 grpc_byte_buffer_destroy(client->recv_buffer);
135 client->send_buffer = nullptr;
136 client->recv_buffer = nullptr;
137 grpc_metadata_array_destroy(&client->recv_initial_metadata);
138 grpc_slice_unref_internal(client->recv_bytes);
139 grpc_slice_unref_internal(client->target_name);
140 grpc_alts_credentials_options_destroy(client->options);
141 gpr_free(client->buffer);
142 grpc_slice_unref_internal(client->handshake_status_details);
143 delete client;
144 }
145 }
146
/* Records progress of the handshake call, and invokes the TSI next callback
 * (client->cb) once that is allowed.
 *
 * A pending result that ends the handshake (it carries a
 * tsi_handshaker_result, or a status other than TSI_OK) is held back until
 * the RECV_STATUS_ON_CLIENT op has completed. Any other pending result is
 * delivered right away. The callback runs outside client->mu, and the
 * delivered recv_message_result is freed afterwards.
 *
 * |receive_status_finished|: true when called from on_status_received(),
 *   false when called from handle_response_done().
 * |pending_recv_message_result|: a newly produced result, whose ownership
 *   passes to this function; nullptr if there is none. At most one result
 *   may be pending at a time. */
static void maybe_complete_tsi_next(
    alts_grpc_handshaker_client* client, bool receive_status_finished,
    recv_message_result* pending_recv_message_result) {
  recv_message_result* r;
  {
    grpc_core::MutexLock lock(&client->mu);
    // Sticky: once the status has been received, it stays received.
    client->receive_status_finished |= receive_status_finished;
    if (pending_recv_message_result != nullptr) {
      GPR_ASSERT(client->pending_recv_message_result == nullptr);
      client->pending_recv_message_result = pending_recv_message_result;
    }
    if (client->pending_recv_message_result == nullptr) {
      // Nothing to deliver yet.
      return;
    }
    const bool have_final_result =
        client->pending_recv_message_result->result != nullptr ||
        client->pending_recv_message_result->status != TSI_OK;
    if (have_final_result && !client->receive_status_finished) {
      // If we've received the final message from the handshake
      // server, or we're about to invoke the TSI next callback
      // with a status other than TSI_OK (which terminates the
      // handshake), then first wait for the RECV_STATUS op to complete.
      return;
    }
    // Take ownership of the result so it is delivered exactly once.
    r = client->pending_recv_message_result;
    client->pending_recv_message_result = nullptr;
  }
  // Invoked without holding client->mu.
  client->cb(r->status, client->user_data, r->bytes_to_send,
             r->bytes_to_send_size, r->result);
  gpr_free(r);
}
178
handle_response_done(alts_grpc_handshaker_client * client,tsi_result status,const unsigned char * bytes_to_send,size_t bytes_to_send_size,tsi_handshaker_result * result)179 static void handle_response_done(alts_grpc_handshaker_client* client,
180 tsi_result status,
181 const unsigned char* bytes_to_send,
182 size_t bytes_to_send_size,
183 tsi_handshaker_result* result) {
184 recv_message_result* p =
185 static_cast<recv_message_result*>(gpr_zalloc(sizeof(*p)));
186 p->status = status;
187 p->bytes_to_send = bytes_to_send;
188 p->bytes_to_send_size = bytes_to_send_size;
189 p->result = result;
190 maybe_complete_tsi_next(client, false /* receive_status_finished */,
191 p /* pending_recv_message_result */);
192 }
193
alts_handshaker_client_handle_response(alts_handshaker_client * c,bool is_ok)194 void alts_handshaker_client_handle_response(alts_handshaker_client* c,
195 bool is_ok) {
196 GPR_ASSERT(c != nullptr);
197 alts_grpc_handshaker_client* client =
198 reinterpret_cast<alts_grpc_handshaker_client*>(c);
199 grpc_byte_buffer* recv_buffer = client->recv_buffer;
200 grpc_status_code status = client->status;
201 alts_tsi_handshaker* handshaker = client->handshaker;
202 /* Invalid input check. */
203 if (client->cb == nullptr) {
204 gpr_log(GPR_ERROR,
205 "client->cb is nullptr in alts_tsi_handshaker_handle_response()");
206 return;
207 }
208 if (handshaker == nullptr) {
209 gpr_log(GPR_ERROR,
210 "handshaker is nullptr in alts_tsi_handshaker_handle_response()");
211 handle_response_done(client, TSI_INTERNAL_ERROR, nullptr, 0, nullptr);
212 return;
213 }
214 /* TSI handshake has been shutdown. */
215 if (alts_tsi_handshaker_has_shutdown(handshaker)) {
216 gpr_log(GPR_ERROR, "TSI handshake shutdown");
217 handle_response_done(client, TSI_HANDSHAKE_SHUTDOWN, nullptr, 0, nullptr);
218 return;
219 }
220 /* Failed grpc call check. */
221 if (!is_ok || status != GRPC_STATUS_OK) {
222 gpr_log(GPR_ERROR, "grpc call made to handshaker service failed");
223 handle_response_done(client, TSI_INTERNAL_ERROR, nullptr, 0, nullptr);
224 return;
225 }
226 if (recv_buffer == nullptr) {
227 gpr_log(GPR_ERROR,
228 "recv_buffer is nullptr in alts_tsi_handshaker_handle_response()");
229 handle_response_done(client, TSI_INTERNAL_ERROR, nullptr, 0, nullptr);
230 return;
231 }
232 upb::Arena arena;
233 grpc_gcp_HandshakerResp* resp =
234 alts_tsi_utils_deserialize_response(recv_buffer, arena.ptr());
235 grpc_byte_buffer_destroy(client->recv_buffer);
236 client->recv_buffer = nullptr;
237 /* Invalid handshaker response check. */
238 if (resp == nullptr) {
239 gpr_log(GPR_ERROR, "alts_tsi_utils_deserialize_response() failed");
240 handle_response_done(client, TSI_DATA_CORRUPTED, nullptr, 0, nullptr);
241 return;
242 }
243 const grpc_gcp_HandshakerStatus* resp_status =
244 grpc_gcp_HandshakerResp_status(resp);
245 if (resp_status == nullptr) {
246 gpr_log(GPR_ERROR, "No status in HandshakerResp");
247 handle_response_done(client, TSI_DATA_CORRUPTED, nullptr, 0, nullptr);
248 return;
249 }
250 upb_strview out_frames = grpc_gcp_HandshakerResp_out_frames(resp);
251 unsigned char* bytes_to_send = nullptr;
252 size_t bytes_to_send_size = 0;
253 if (out_frames.size > 0) {
254 bytes_to_send_size = out_frames.size;
255 while (bytes_to_send_size > client->buffer_size) {
256 client->buffer_size *= 2;
257 client->buffer = static_cast<unsigned char*>(
258 gpr_realloc(client->buffer, client->buffer_size));
259 }
260 memcpy(client->buffer, out_frames.data, bytes_to_send_size);
261 bytes_to_send = client->buffer;
262 }
263 tsi_handshaker_result* result = nullptr;
264 if (is_handshake_finished_properly(resp)) {
265 tsi_result status =
266 alts_tsi_handshaker_result_create(resp, client->is_client, &result);
267 if (status != TSI_OK) {
268 gpr_log(GPR_ERROR, "alts_tsi_handshaker_result_create() failed");
269 handle_response_done(client, status, nullptr, 0, nullptr);
270 return;
271 }
272 alts_tsi_handshaker_result_set_unused_bytes(
273 result, &client->recv_bytes,
274 grpc_gcp_HandshakerResp_bytes_consumed(resp));
275 }
276 grpc_status_code code = static_cast<grpc_status_code>(
277 grpc_gcp_HandshakerStatus_code(resp_status));
278 if (code != GRPC_STATUS_OK) {
279 upb_strview details = grpc_gcp_HandshakerStatus_details(resp_status);
280 if (details.size > 0) {
281 char* error_details = static_cast<char*>(gpr_zalloc(details.size + 1));
282 memcpy(error_details, details.data, details.size);
283 gpr_log(GPR_ERROR, "Error from handshaker service:%s", error_details);
284 gpr_free(error_details);
285 }
286 }
287 // TODO(apolcyn): consider short ciruiting handle_response_done and
288 // invoking the TSI callback directly if we aren't done yet, if
289 // handle_response_done's allocation per message received causes
290 // a performance issue.
291 handle_response_done(client, alts_tsi_utils_convert_to_tsi_result(code),
292 bytes_to_send, bytes_to_send_size, result);
293 }
294
continue_make_grpc_call(alts_grpc_handshaker_client * client,bool is_start)295 static tsi_result continue_make_grpc_call(alts_grpc_handshaker_client* client,
296 bool is_start) {
297 GPR_ASSERT(client != nullptr);
298 grpc_op ops[kHandshakerClientOpNum];
299 memset(ops, 0, sizeof(ops));
300 grpc_op* op = ops;
301 if (is_start) {
302 op->op = GRPC_OP_RECV_STATUS_ON_CLIENT;
303 op->data.recv_status_on_client.trailing_metadata = nullptr;
304 op->data.recv_status_on_client.status = &client->handshake_status_code;
305 op->data.recv_status_on_client.status_details =
306 &client->handshake_status_details;
307 op->flags = 0;
308 op->reserved = nullptr;
309 op++;
310 GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
311 gpr_ref(&client->refs);
312 grpc_call_error call_error =
313 client->grpc_caller(client->call, ops, static_cast<size_t>(op - ops),
314 &client->on_status_received);
315 // TODO(apolcyn): return the error here instead, as done for other ops?
316 GPR_ASSERT(call_error == GRPC_CALL_OK);
317 memset(ops, 0, sizeof(ops));
318 op = ops;
319 op->op = GRPC_OP_SEND_INITIAL_METADATA;
320 op->data.send_initial_metadata.count = 0;
321 op++;
322 GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
323 op->op = GRPC_OP_RECV_INITIAL_METADATA;
324 op->data.recv_initial_metadata.recv_initial_metadata =
325 &client->recv_initial_metadata;
326 op++;
327 GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
328 }
329 op->op = GRPC_OP_SEND_MESSAGE;
330 op->data.send_message.send_message = client->send_buffer;
331 op++;
332 GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
333 op->op = GRPC_OP_RECV_MESSAGE;
334 op->data.recv_message.recv_message = &client->recv_buffer;
335 op++;
336 GPR_ASSERT(op - ops <= kHandshakerClientOpNum);
337 GPR_ASSERT(client->grpc_caller != nullptr);
338 if (client->grpc_caller(client->call, ops, static_cast<size_t>(op - ops),
339 &client->on_handshaker_service_resp_recv) !=
340 GRPC_CALL_OK) {
341 gpr_log(GPR_ERROR, "Start batch operation failed");
342 return TSI_INTERNAL_ERROR;
343 }
344 return TSI_OK;
345 }
346
347 // TODO(apolcyn): remove this global queue when we can safely rely
348 // on a MAX_CONCURRENT_STREAMS setting in the ALTS handshake server to
349 // limit the number of concurrent handshakes.
350 namespace {
351
352 class HandshakeQueue {
353 public:
HandshakeQueue(size_t max_outstanding_handshakes)354 explicit HandshakeQueue(size_t max_outstanding_handshakes)
355 : max_outstanding_handshakes_(max_outstanding_handshakes) {}
356
RequestHandshake(alts_grpc_handshaker_client * client)357 void RequestHandshake(alts_grpc_handshaker_client* client) {
358 {
359 grpc_core::MutexLock lock(&mu_);
360 if (outstanding_handshakes_ == max_outstanding_handshakes_) {
361 // Max number already running, add to queue.
362 queued_handshakes_.push_back(client);
363 return;
364 }
365 // Start the handshake immediately.
366 ++outstanding_handshakes_;
367 }
368 continue_make_grpc_call(client, true /* is_start */);
369 }
370
HandshakeDone()371 void HandshakeDone() {
372 alts_grpc_handshaker_client* client = nullptr;
373 {
374 grpc_core::MutexLock lock(&mu_);
375 if (queued_handshakes_.empty()) {
376 // Nothing more in queue. Decrement count and return immediately.
377 --outstanding_handshakes_;
378 return;
379 }
380 // Remove next entry from queue and start the handshake.
381 client = queued_handshakes_.front();
382 queued_handshakes_.pop_front();
383 }
384 continue_make_grpc_call(client, true /* is_start */);
385 }
386
387 private:
388 grpc_core::Mutex mu_;
389 std::list<alts_grpc_handshaker_client*> queued_handshakes_;
390 size_t outstanding_handshakes_ = 0;
391 const size_t max_outstanding_handshakes_;
392 };
393
394 gpr_once g_queued_handshakes_init = GPR_ONCE_INIT;
395 /* Using separate queues for client and server handshakes is a
396 * hack that's mainly intended to satisfy the alts_concurrent_connectivity_test,
397 * which runs many concurrent handshakes where both endpoints
398 * are in the same process; this situation is problematic with a
399 * single queue because we have a high chance of using up all outstanding
400 * slots in the queue, such that there aren't any
401 * mutual client/server handshakes outstanding at the same time and
402 * able to make progress. */
403 HandshakeQueue* g_client_handshake_queue;
404 HandshakeQueue* g_server_handshake_queue;
405
DoHandshakeQueuesInit(void)406 void DoHandshakeQueuesInit(void) {
407 const size_t per_queue_max_outstanding_handshakes = 40;
408 g_client_handshake_queue =
409 new HandshakeQueue(per_queue_max_outstanding_handshakes);
410 g_server_handshake_queue =
411 new HandshakeQueue(per_queue_max_outstanding_handshakes);
412 }
413
RequestHandshake(alts_grpc_handshaker_client * client,bool is_client)414 void RequestHandshake(alts_grpc_handshaker_client* client, bool is_client) {
415 gpr_once_init(&g_queued_handshakes_init, DoHandshakeQueuesInit);
416 HandshakeQueue* queue =
417 is_client ? g_client_handshake_queue : g_server_handshake_queue;
418 queue->RequestHandshake(client);
419 }
420
HandshakeDone(bool is_client)421 void HandshakeDone(bool is_client) {
422 HandshakeQueue* queue =
423 is_client ? g_client_handshake_queue : g_server_handshake_queue;
424 queue->HandshakeDone();
425 }
426
427 }; // namespace
428
429 /**
430 * Populate grpc operation data with the fields of ALTS handshaker client and
431 * make a grpc call.
432 */
make_grpc_call(alts_handshaker_client * c,bool is_start)433 static tsi_result make_grpc_call(alts_handshaker_client* c, bool is_start) {
434 GPR_ASSERT(c != nullptr);
435 alts_grpc_handshaker_client* client =
436 reinterpret_cast<alts_grpc_handshaker_client*>(c);
437 if (is_start) {
438 RequestHandshake(client, client->is_client);
439 return TSI_OK;
440 } else {
441 return continue_make_grpc_call(client, is_start);
442 }
443 }
444
on_status_received(void * arg,grpc_error * error)445 static void on_status_received(void* arg, grpc_error* error) {
446 alts_grpc_handshaker_client* client =
447 static_cast<alts_grpc_handshaker_client*>(arg);
448 if (client->handshake_status_code != GRPC_STATUS_OK) {
449 // TODO(apolcyn): consider overriding the handshake result's
450 // status from the final ALTS message with the status here.
451 char* status_details =
452 grpc_slice_to_c_string(client->handshake_status_details);
453 gpr_log(GPR_INFO,
454 "alts_grpc_handshaker_client:%p on_status_received "
455 "status:%d details:|%s| error:|%s|",
456 client, client->handshake_status_code, status_details,
457 grpc_error_string(error));
458 gpr_free(status_details);
459 }
460 maybe_complete_tsi_next(client, true /* receive_status_finished */,
461 nullptr /* pending_recv_message_result */);
462 HandshakeDone(client->is_client);
463 alts_grpc_handshaker_client_unref(client);
464 }
465
466 /* Serializes a grpc_gcp_HandshakerReq message into a buffer and returns newly
467 * grpc_byte_buffer holding it. */
get_serialized_handshaker_req(grpc_gcp_HandshakerReq * req,upb_arena * arena)468 static grpc_byte_buffer* get_serialized_handshaker_req(
469 grpc_gcp_HandshakerReq* req, upb_arena* arena) {
470 size_t buf_length;
471 char* buf = grpc_gcp_HandshakerReq_serialize(req, arena, &buf_length);
472 if (buf == nullptr) {
473 return nullptr;
474 }
475 grpc_slice slice = grpc_slice_from_copied_buffer(buf, buf_length);
476 grpc_byte_buffer* byte_buffer = grpc_raw_byte_buffer_create(&slice, 1);
477 grpc_slice_unref_internal(slice);
478 return byte_buffer;
479 }
480
481 /* Create and populate a client_start handshaker request, then serialize it. */
get_serialized_start_client(alts_handshaker_client * c)482 static grpc_byte_buffer* get_serialized_start_client(
483 alts_handshaker_client* c) {
484 GPR_ASSERT(c != nullptr);
485 alts_grpc_handshaker_client* client =
486 reinterpret_cast<alts_grpc_handshaker_client*>(c);
487 upb::Arena arena;
488 grpc_gcp_HandshakerReq* req = grpc_gcp_HandshakerReq_new(arena.ptr());
489 grpc_gcp_StartClientHandshakeReq* start_client =
490 grpc_gcp_HandshakerReq_mutable_client_start(req, arena.ptr());
491 grpc_gcp_StartClientHandshakeReq_set_handshake_security_protocol(
492 start_client, grpc_gcp_ALTS);
493 grpc_gcp_StartClientHandshakeReq_add_application_protocols(
494 start_client, upb_strview_makez(ALTS_APPLICATION_PROTOCOL), arena.ptr());
495 grpc_gcp_StartClientHandshakeReq_add_record_protocols(
496 start_client, upb_strview_makez(ALTS_RECORD_PROTOCOL), arena.ptr());
497 grpc_gcp_RpcProtocolVersions* client_version =
498 grpc_gcp_StartClientHandshakeReq_mutable_rpc_versions(start_client,
499 arena.ptr());
500 grpc_gcp_RpcProtocolVersions_assign_from_struct(
501 client_version, arena.ptr(), &client->options->rpc_versions);
502 grpc_gcp_StartClientHandshakeReq_set_target_name(
503 start_client,
504 upb_strview_make(reinterpret_cast<const char*>(
505 GRPC_SLICE_START_PTR(client->target_name)),
506 GRPC_SLICE_LENGTH(client->target_name)));
507 target_service_account* ptr =
508 (reinterpret_cast<grpc_alts_credentials_client_options*>(client->options))
509 ->target_account_list_head;
510 while (ptr != nullptr) {
511 grpc_gcp_Identity* target_identity =
512 grpc_gcp_StartClientHandshakeReq_add_target_identities(start_client,
513 arena.ptr());
514 grpc_gcp_Identity_set_service_account(target_identity,
515 upb_strview_makez(ptr->data));
516 ptr = ptr->next;
517 }
518 grpc_gcp_StartClientHandshakeReq_set_max_frame_size(
519 start_client, static_cast<uint32_t>(client->max_frame_size));
520 return get_serialized_handshaker_req(req, arena.ptr());
521 }
522
handshaker_client_start_client(alts_handshaker_client * c)523 static tsi_result handshaker_client_start_client(alts_handshaker_client* c) {
524 if (c == nullptr) {
525 gpr_log(GPR_ERROR, "client is nullptr in handshaker_client_start_client()");
526 return TSI_INVALID_ARGUMENT;
527 }
528 grpc_byte_buffer* buffer = get_serialized_start_client(c);
529 alts_grpc_handshaker_client* client =
530 reinterpret_cast<alts_grpc_handshaker_client*>(c);
531 if (buffer == nullptr) {
532 gpr_log(GPR_ERROR, "get_serialized_start_client() failed");
533 return TSI_INTERNAL_ERROR;
534 }
535 handshaker_client_send_buffer_destroy(client);
536 client->send_buffer = buffer;
537 tsi_result result = make_grpc_call(&client->base, true /* is_start */);
538 if (result != TSI_OK) {
539 gpr_log(GPR_ERROR, "make_grpc_call() failed");
540 }
541 return result;
542 }
543
544 /* Create and populate a start_server handshaker request, then serialize it. */
get_serialized_start_server(alts_handshaker_client * c,grpc_slice * bytes_received)545 static grpc_byte_buffer* get_serialized_start_server(
546 alts_handshaker_client* c, grpc_slice* bytes_received) {
547 GPR_ASSERT(c != nullptr);
548 GPR_ASSERT(bytes_received != nullptr);
549 alts_grpc_handshaker_client* client =
550 reinterpret_cast<alts_grpc_handshaker_client*>(c);
551
552 upb::Arena arena;
553 grpc_gcp_HandshakerReq* req = grpc_gcp_HandshakerReq_new(arena.ptr());
554
555 grpc_gcp_StartServerHandshakeReq* start_server =
556 grpc_gcp_HandshakerReq_mutable_server_start(req, arena.ptr());
557 grpc_gcp_StartServerHandshakeReq_add_application_protocols(
558 start_server, upb_strview_makez(ALTS_APPLICATION_PROTOCOL), arena.ptr());
559 grpc_gcp_ServerHandshakeParameters* value =
560 grpc_gcp_ServerHandshakeParameters_new(arena.ptr());
561 grpc_gcp_ServerHandshakeParameters_add_record_protocols(
562 value, upb_strview_makez(ALTS_RECORD_PROTOCOL), arena.ptr());
563 grpc_gcp_StartServerHandshakeReq_handshake_parameters_set(
564 start_server, grpc_gcp_ALTS, value, arena.ptr());
565 grpc_gcp_StartServerHandshakeReq_set_in_bytes(
566 start_server, upb_strview_make(reinterpret_cast<const char*>(
567 GRPC_SLICE_START_PTR(*bytes_received)),
568 GRPC_SLICE_LENGTH(*bytes_received)));
569 grpc_gcp_RpcProtocolVersions* server_version =
570 grpc_gcp_StartServerHandshakeReq_mutable_rpc_versions(start_server,
571 arena.ptr());
572 grpc_gcp_RpcProtocolVersions_assign_from_struct(
573 server_version, arena.ptr(), &client->options->rpc_versions);
574 grpc_gcp_StartServerHandshakeReq_set_max_frame_size(
575 start_server, static_cast<uint32_t>(client->max_frame_size));
576 return get_serialized_handshaker_req(req, arena.ptr());
577 }
578
handshaker_client_start_server(alts_handshaker_client * c,grpc_slice * bytes_received)579 static tsi_result handshaker_client_start_server(alts_handshaker_client* c,
580 grpc_slice* bytes_received) {
581 if (c == nullptr || bytes_received == nullptr) {
582 gpr_log(GPR_ERROR, "Invalid arguments to handshaker_client_start_server()");
583 return TSI_INVALID_ARGUMENT;
584 }
585 alts_grpc_handshaker_client* client =
586 reinterpret_cast<alts_grpc_handshaker_client*>(c);
587 grpc_byte_buffer* buffer = get_serialized_start_server(c, bytes_received);
588 if (buffer == nullptr) {
589 gpr_log(GPR_ERROR, "get_serialized_start_server() failed");
590 return TSI_INTERNAL_ERROR;
591 }
592 handshaker_client_send_buffer_destroy(client);
593 client->send_buffer = buffer;
594 tsi_result result = make_grpc_call(&client->base, true /* is_start */);
595 if (result != TSI_OK) {
596 gpr_log(GPR_ERROR, "make_grpc_call() failed");
597 }
598 return result;
599 }
600
601 /* Create and populate a next handshaker request, then serialize it. */
get_serialized_next(grpc_slice * bytes_received)602 static grpc_byte_buffer* get_serialized_next(grpc_slice* bytes_received) {
603 GPR_ASSERT(bytes_received != nullptr);
604 upb::Arena arena;
605 grpc_gcp_HandshakerReq* req = grpc_gcp_HandshakerReq_new(arena.ptr());
606 grpc_gcp_NextHandshakeMessageReq* next =
607 grpc_gcp_HandshakerReq_mutable_next(req, arena.ptr());
608 grpc_gcp_NextHandshakeMessageReq_set_in_bytes(
609 next, upb_strview_make(reinterpret_cast<const char*> GRPC_SLICE_START_PTR(
610 *bytes_received),
611 GRPC_SLICE_LENGTH(*bytes_received)));
612 return get_serialized_handshaker_req(req, arena.ptr());
613 }
614
handshaker_client_next(alts_handshaker_client * c,grpc_slice * bytes_received)615 static tsi_result handshaker_client_next(alts_handshaker_client* c,
616 grpc_slice* bytes_received) {
617 if (c == nullptr || bytes_received == nullptr) {
618 gpr_log(GPR_ERROR, "Invalid arguments to handshaker_client_next()");
619 return TSI_INVALID_ARGUMENT;
620 }
621 alts_grpc_handshaker_client* client =
622 reinterpret_cast<alts_grpc_handshaker_client*>(c);
623 grpc_slice_unref_internal(client->recv_bytes);
624 client->recv_bytes = grpc_slice_ref_internal(*bytes_received);
625 grpc_byte_buffer* buffer = get_serialized_next(bytes_received);
626 if (buffer == nullptr) {
627 gpr_log(GPR_ERROR, "get_serialized_next() failed");
628 return TSI_INTERNAL_ERROR;
629 }
630 handshaker_client_send_buffer_destroy(client);
631 client->send_buffer = buffer;
632 tsi_result result = make_grpc_call(&client->base, false /* is_start */);
633 if (result != TSI_OK) {
634 gpr_log(GPR_ERROR, "make_grpc_call() failed");
635 }
636 return result;
637 }
638
handshaker_client_shutdown(alts_handshaker_client * c)639 static void handshaker_client_shutdown(alts_handshaker_client* c) {
640 GPR_ASSERT(c != nullptr);
641 alts_grpc_handshaker_client* client =
642 reinterpret_cast<alts_grpc_handshaker_client*>(c);
643 if (client->call != nullptr) {
644 grpc_call_cancel_internal(client->call);
645 }
646 }
647
handshaker_call_unref(void * arg,grpc_error *)648 static void handshaker_call_unref(void* arg, grpc_error* /* error */) {
649 grpc_call* call = static_cast<grpc_call*>(arg);
650 grpc_call_unref(call);
651 }
652
handshaker_client_destruct(alts_handshaker_client * c)653 static void handshaker_client_destruct(alts_handshaker_client* c) {
654 if (c == nullptr) {
655 return;
656 }
657 alts_grpc_handshaker_client* client =
658 reinterpret_cast<alts_grpc_handshaker_client*>(c);
659 if (client->call != nullptr) {
660 // Throw this grpc_call_unref over to the ExecCtx so that
661 // we invoke it at the bottom of the call stack and
662 // prevent lock inversion problems due to nested ExecCtx flushing.
663 // TODO(apolcyn): we could remove this indirection and call
664 // grpc_call_unref inline if there was an internal variant of
665 // grpc_call_unref that didn't need to flush an ExecCtx.
666 if (grpc_core::ExecCtx::Get() == nullptr) {
667 // Unref handshaker call if there is no exec_ctx, e.g., in the case of
668 // Envoy ALTS transport socket.
669 grpc_call_unref(client->call);
670 } else {
671 // Using existing exec_ctx to unref handshaker call.
672 grpc_core::ExecCtx::Run(
673 DEBUG_LOCATION,
674 GRPC_CLOSURE_CREATE(handshaker_call_unref, client->call,
675 grpc_schedule_on_exec_ctx),
676 GRPC_ERROR_NONE);
677 }
678 }
679 }
680
681 static const alts_handshaker_client_vtable vtable = {
682 handshaker_client_start_client, handshaker_client_start_server,
683 handshaker_client_next, handshaker_client_shutdown,
684 handshaker_client_destruct};
685
alts_grpc_handshaker_client_create(alts_tsi_handshaker * handshaker,grpc_channel * channel,const char * handshaker_service_url,grpc_pollset_set * interested_parties,grpc_alts_credentials_options * options,const grpc_slice & target_name,grpc_iomgr_cb_func grpc_cb,tsi_handshaker_on_next_done_cb cb,void * user_data,alts_handshaker_client_vtable * vtable_for_testing,bool is_client,size_t max_frame_size)686 alts_handshaker_client* alts_grpc_handshaker_client_create(
687 alts_tsi_handshaker* handshaker, grpc_channel* channel,
688 const char* handshaker_service_url, grpc_pollset_set* interested_parties,
689 grpc_alts_credentials_options* options, const grpc_slice& target_name,
690 grpc_iomgr_cb_func grpc_cb, tsi_handshaker_on_next_done_cb cb,
691 void* user_data, alts_handshaker_client_vtable* vtable_for_testing,
692 bool is_client, size_t max_frame_size) {
693 if (channel == nullptr || handshaker_service_url == nullptr) {
694 gpr_log(GPR_ERROR, "Invalid arguments to alts_handshaker_client_create()");
695 return nullptr;
696 }
697 alts_grpc_handshaker_client* client = new alts_grpc_handshaker_client();
698 memset(&client->base, 0, sizeof(client->base));
699 client->base.vtable =
700 vtable_for_testing == nullptr ? &vtable : vtable_for_testing;
701 gpr_ref_init(&client->refs, 1);
702 client->handshaker = handshaker;
703 client->grpc_caller = grpc_call_start_batch_and_execute;
704 grpc_metadata_array_init(&client->recv_initial_metadata);
705 client->cb = cb;
706 client->user_data = user_data;
707 client->options = grpc_alts_credentials_options_copy(options);
708 client->target_name = grpc_slice_copy(target_name);
709 client->is_client = is_client;
710 client->recv_bytes = grpc_empty_slice();
711 client->buffer_size = TSI_ALTS_INITIAL_BUFFER_SIZE;
712 client->buffer = static_cast<unsigned char*>(gpr_zalloc(client->buffer_size));
713 client->handshake_status_details = grpc_empty_slice();
714 client->max_frame_size = max_frame_size;
715 grpc_slice slice = grpc_slice_from_copied_string(handshaker_service_url);
716 client->call =
717 strcmp(handshaker_service_url, ALTS_HANDSHAKER_SERVICE_URL_FOR_TESTING) ==
718 0
719 ? nullptr
720 : grpc_channel_create_pollset_set_call(
721 channel, nullptr, GRPC_PROPAGATE_DEFAULTS, interested_parties,
722 grpc_slice_from_static_string(ALTS_SERVICE_METHOD), &slice,
723 GRPC_MILLIS_INF_FUTURE, nullptr);
724 GRPC_CLOSURE_INIT(&client->on_handshaker_service_resp_recv, grpc_cb, client,
725 grpc_schedule_on_exec_ctx);
726 GRPC_CLOSURE_INIT(&client->on_status_received, on_status_received, client,
727 grpc_schedule_on_exec_ctx);
728 grpc_slice_unref_internal(slice);
729 return &client->base;
730 }
731
732 namespace grpc_core {
733 namespace internal {
734
alts_handshaker_client_set_grpc_caller_for_testing(alts_handshaker_client * c,alts_grpc_caller caller)735 void alts_handshaker_client_set_grpc_caller_for_testing(
736 alts_handshaker_client* c, alts_grpc_caller caller) {
737 GPR_ASSERT(c != nullptr && caller != nullptr);
738 alts_grpc_handshaker_client* client =
739 reinterpret_cast<alts_grpc_handshaker_client*>(c);
740 client->grpc_caller = caller;
741 }
742
alts_handshaker_client_get_send_buffer_for_testing(alts_handshaker_client * c)743 grpc_byte_buffer* alts_handshaker_client_get_send_buffer_for_testing(
744 alts_handshaker_client* c) {
745 GPR_ASSERT(c != nullptr);
746 alts_grpc_handshaker_client* client =
747 reinterpret_cast<alts_grpc_handshaker_client*>(c);
748 return client->send_buffer;
749 }
750
alts_handshaker_client_get_recv_buffer_addr_for_testing(alts_handshaker_client * c)751 grpc_byte_buffer** alts_handshaker_client_get_recv_buffer_addr_for_testing(
752 alts_handshaker_client* c) {
753 GPR_ASSERT(c != nullptr);
754 alts_grpc_handshaker_client* client =
755 reinterpret_cast<alts_grpc_handshaker_client*>(c);
756 return &client->recv_buffer;
757 }
758
alts_handshaker_client_get_initial_metadata_for_testing(alts_handshaker_client * c)759 grpc_metadata_array* alts_handshaker_client_get_initial_metadata_for_testing(
760 alts_handshaker_client* c) {
761 GPR_ASSERT(c != nullptr);
762 alts_grpc_handshaker_client* client =
763 reinterpret_cast<alts_grpc_handshaker_client*>(c);
764 return &client->recv_initial_metadata;
765 }
766
alts_handshaker_client_set_recv_bytes_for_testing(alts_handshaker_client * c,grpc_slice * recv_bytes)767 void alts_handshaker_client_set_recv_bytes_for_testing(
768 alts_handshaker_client* c, grpc_slice* recv_bytes) {
769 GPR_ASSERT(c != nullptr);
770 alts_grpc_handshaker_client* client =
771 reinterpret_cast<alts_grpc_handshaker_client*>(c);
772 client->recv_bytes = grpc_slice_ref_internal(*recv_bytes);
773 }
774
alts_handshaker_client_set_fields_for_testing(alts_handshaker_client * c,alts_tsi_handshaker * handshaker,tsi_handshaker_on_next_done_cb cb,void * user_data,grpc_byte_buffer * recv_buffer,grpc_status_code status)775 void alts_handshaker_client_set_fields_for_testing(
776 alts_handshaker_client* c, alts_tsi_handshaker* handshaker,
777 tsi_handshaker_on_next_done_cb cb, void* user_data,
778 grpc_byte_buffer* recv_buffer, grpc_status_code status) {
779 GPR_ASSERT(c != nullptr);
780 alts_grpc_handshaker_client* client =
781 reinterpret_cast<alts_grpc_handshaker_client*>(c);
782 client->handshaker = handshaker;
783 client->cb = cb;
784 client->user_data = user_data;
785 client->recv_buffer = recv_buffer;
786 client->status = status;
787 }
788
alts_handshaker_client_check_fields_for_testing(alts_handshaker_client * c,tsi_handshaker_on_next_done_cb cb,void * user_data,bool has_sent_start_message,grpc_slice * recv_bytes)789 void alts_handshaker_client_check_fields_for_testing(
790 alts_handshaker_client* c, tsi_handshaker_on_next_done_cb cb,
791 void* user_data, bool has_sent_start_message, grpc_slice* recv_bytes) {
792 GPR_ASSERT(c != nullptr);
793 alts_grpc_handshaker_client* client =
794 reinterpret_cast<alts_grpc_handshaker_client*>(c);
795 GPR_ASSERT(client->cb == cb);
796 GPR_ASSERT(client->user_data == user_data);
797 if (recv_bytes != nullptr) {
798 GPR_ASSERT(grpc_slice_cmp(client->recv_bytes, *recv_bytes) == 0);
799 }
800 GPR_ASSERT(alts_tsi_handshaker_get_has_sent_start_message_for_testing(
801 client->handshaker) == has_sent_start_message);
802 }
803
alts_handshaker_client_set_vtable_for_testing(alts_handshaker_client * c,alts_handshaker_client_vtable * vtable)804 void alts_handshaker_client_set_vtable_for_testing(
805 alts_handshaker_client* c, alts_handshaker_client_vtable* vtable) {
806 GPR_ASSERT(c != nullptr);
807 GPR_ASSERT(vtable != nullptr);
808 alts_grpc_handshaker_client* client =
809 reinterpret_cast<alts_grpc_handshaker_client*>(c);
810 client->base.vtable = vtable;
811 }
812
alts_handshaker_client_get_handshaker_for_testing(alts_handshaker_client * c)813 alts_tsi_handshaker* alts_handshaker_client_get_handshaker_for_testing(
814 alts_handshaker_client* c) {
815 GPR_ASSERT(c != nullptr);
816 alts_grpc_handshaker_client* client =
817 reinterpret_cast<alts_grpc_handshaker_client*>(c);
818 return client->handshaker;
819 }
820
alts_handshaker_client_set_cb_for_testing(alts_handshaker_client * c,tsi_handshaker_on_next_done_cb cb)821 void alts_handshaker_client_set_cb_for_testing(
822 alts_handshaker_client* c, tsi_handshaker_on_next_done_cb cb) {
823 GPR_ASSERT(c != nullptr);
824 alts_grpc_handshaker_client* client =
825 reinterpret_cast<alts_grpc_handshaker_client*>(c);
826 client->cb = cb;
827 }
828
alts_handshaker_client_get_closure_for_testing(alts_handshaker_client * c)829 grpc_closure* alts_handshaker_client_get_closure_for_testing(
830 alts_handshaker_client* c) {
831 GPR_ASSERT(c != nullptr);
832 alts_grpc_handshaker_client* client =
833 reinterpret_cast<alts_grpc_handshaker_client*>(c);
834 return &client->on_handshaker_service_resp_recv;
835 }
836
alts_handshaker_client_ref_for_testing(alts_handshaker_client * c)837 void alts_handshaker_client_ref_for_testing(alts_handshaker_client* c) {
838 alts_grpc_handshaker_client* client =
839 reinterpret_cast<alts_grpc_handshaker_client*>(c);
840 gpr_ref(&client->refs);
841 }
842
alts_handshaker_client_on_status_received_for_testing(alts_handshaker_client * c,grpc_status_code status,grpc_error * error)843 void alts_handshaker_client_on_status_received_for_testing(
844 alts_handshaker_client* c, grpc_status_code status, grpc_error* error) {
845 // We first make sure that the handshake queue has been initialized
846 // here because there are tests that use this API that mock out
847 // other parts of the alts_handshaker_client in such a way that the
848 // code path that would normally ensure that the handshake queue
849 // has been initialized isn't taken.
850 gpr_once_init(&g_queued_handshakes_init, DoHandshakeQueuesInit);
851 alts_grpc_handshaker_client* client =
852 reinterpret_cast<alts_grpc_handshaker_client*>(c);
853 client->handshake_status_code = status;
854 client->handshake_status_details = grpc_empty_slice();
855 grpc_core::Closure::Run(DEBUG_LOCATION, &client->on_status_received, error);
856 }
857
858 } // namespace internal
859 } // namespace grpc_core
860
alts_handshaker_client_start_client(alts_handshaker_client * client)861 tsi_result alts_handshaker_client_start_client(alts_handshaker_client* client) {
862 if (client != nullptr && client->vtable != nullptr &&
863 client->vtable->client_start != nullptr) {
864 return client->vtable->client_start(client);
865 }
866 gpr_log(GPR_ERROR,
867 "client or client->vtable has not been initialized properly");
868 return TSI_INVALID_ARGUMENT;
869 }
870
alts_handshaker_client_start_server(alts_handshaker_client * client,grpc_slice * bytes_received)871 tsi_result alts_handshaker_client_start_server(alts_handshaker_client* client,
872 grpc_slice* bytes_received) {
873 if (client != nullptr && client->vtable != nullptr &&
874 client->vtable->server_start != nullptr) {
875 return client->vtable->server_start(client, bytes_received);
876 }
877 gpr_log(GPR_ERROR,
878 "client or client->vtable has not been initialized properly");
879 return TSI_INVALID_ARGUMENT;
880 }
881
alts_handshaker_client_next(alts_handshaker_client * client,grpc_slice * bytes_received)882 tsi_result alts_handshaker_client_next(alts_handshaker_client* client,
883 grpc_slice* bytes_received) {
884 if (client != nullptr && client->vtable != nullptr &&
885 client->vtable->next != nullptr) {
886 return client->vtable->next(client, bytes_received);
887 }
888 gpr_log(GPR_ERROR,
889 "client or client->vtable has not been initialized properly");
890 return TSI_INVALID_ARGUMENT;
891 }
892
alts_handshaker_client_shutdown(alts_handshaker_client * client)893 void alts_handshaker_client_shutdown(alts_handshaker_client* client) {
894 if (client != nullptr && client->vtable != nullptr &&
895 client->vtable->shutdown != nullptr) {
896 client->vtable->shutdown(client);
897 }
898 }
899
alts_handshaker_client_destroy(alts_handshaker_client * c)900 void alts_handshaker_client_destroy(alts_handshaker_client* c) {
901 if (c != nullptr) {
902 alts_grpc_handshaker_client* client =
903 reinterpret_cast<alts_grpc_handshaker_client*>(c);
904 alts_grpc_handshaker_client_unref(client);
905 }
906 }
907