1 //
2 //
3 // Copyright 2015 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include "src/core/handshaker/security/security_handshaker.h"
20
21 #include <grpc/grpc_security.h>
22 #include <grpc/grpc_security_constants.h>
23 #include <grpc/impl/channel_arg_names.h>
24 #include <grpc/slice.h>
25 #include <grpc/slice_buffer.h>
26 #include <grpc/support/alloc.h>
27 #include <grpc/support/port_platform.h>
28 #include <limits.h>
29 #include <stdint.h>
30 #include <string.h>
31
32 #include <algorithm>
33 #include <memory>
34 #include <string>
35 #include <utility>
36
37 #include "absl/base/attributes.h"
38 #include "absl/functional/any_invocable.h"
39 #include "absl/log/check.h"
40 #include "absl/status/status.h"
41 #include "absl/strings/str_cat.h"
42 #include "absl/strings/string_view.h"
43 #include "absl/types/optional.h"
44 #include "src/core/channelz/channelz.h"
45 #include "src/core/config/core_configuration.h"
46 #include "src/core/handshaker/handshaker.h"
47 #include "src/core/handshaker/handshaker_factory.h"
48 #include "src/core/handshaker/handshaker_registry.h"
49 #include "src/core/handshaker/security/secure_endpoint.h"
50 #include "src/core/lib/channel/channel_args.h"
51 #include "src/core/lib/iomgr/closure.h"
52 #include "src/core/lib/iomgr/endpoint.h"
53 #include "src/core/lib/iomgr/error.h"
54 #include "src/core/lib/iomgr/exec_ctx.h"
55 #include "src/core/lib/iomgr/iomgr_fwd.h"
56 #include "src/core/lib/iomgr/tcp_server.h"
57 #include "src/core/lib/security/context/security_context.h"
58 #include "src/core/lib/slice/slice.h"
59 #include "src/core/lib/slice/slice_internal.h"
60 #include "src/core/telemetry/stats.h"
61 #include "src/core/telemetry/stats_data.h"
62 #include "src/core/tsi/transport_security_grpc.h"
63 #include "src/core/util/debug_location.h"
64 #include "src/core/util/ref_counted_ptr.h"
65 #include "src/core/util/sync.h"
66 #include "src/core/util/unique_type_name.h"
67
// Initial capacity, in bytes, of SecurityHandshaker::handshake_buffer_.
// MoveReadBufferIntoHandshakeBuffer() grows that buffer (via gpr_realloc) when
// more bytes than this are waiting in the read buffer.
#define GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE 256
69
70 namespace grpc_core {
71
72 namespace {
73
// Handshaker that drives a TSI handshake over the connection's endpoint.
// Bytes read from the peer are fed to tsi_handshaker_next(), TSI output is
// written back to the peer, and once TSI yields a handshaker result the peer
// is checked by the security connector. On success the endpoint is wrapped in
// a secure endpoint when the TSI result provides a frame protector.
class SecurityHandshaker : public Handshaker {
 public:
  // Takes ownership of `handshaker` (destroyed in the destructor) and takes a
  // ref on `connector`. GRPC_ARG_TSI_MAX_FRAME_SIZE is read from `args`.
  SecurityHandshaker(tsi_handshaker* handshaker,
                     grpc_security_connector* connector,
                     const ChannelArgs& args);
  ~SecurityHandshaker() override;
  absl::string_view name() const override { return "security"; }
  void DoHandshake(
      HandshakerArgs* args,
      absl::AnyInvocable<void(absl::Status)> on_handshake_done) override;
  void Shutdown(absl::Status error) override;

 private:
  // Runs one tsi_handshaker_next() step on the given bytes.
  grpc_error_handle DoHandshakerNextLocked(const unsigned char* bytes_received,
                                           size_t bytes_received_size)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Processes the output of one TSI step: reads more, writes to the peer, or
  // starts the peer check.
  grpc_error_handle OnHandshakeNextDoneLocked(
      tsi_result result, const unsigned char* bytes_to_send,
      size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Shuts the TSI handshaker down (at most once) and reports `error`.
  void HandshakeFailedLocked(absl::Status error)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Invokes on_handshake_done_ with `status`.
  void Finish(absl::Status status);

  // Endpoint read/write completions. The *Scheduler variants re-dispatch the
  // real handler onto the EventEngine because the endpoint may invoke them
  // inline while mu_ is still held.
  void OnHandshakeDataReceivedFromPeerFn(absl::Status error);
  void OnHandshakeDataSentToPeerFn(absl::Status error);
  void OnHandshakeDataReceivedFromPeerFnScheduler(grpc_error_handle error);
  void OnHandshakeDataSentToPeerFnScheduler(grpc_error_handle error);
  // C-style callback given to tsi_handshaker_next() for asynchronous steps.
  static void OnHandshakeNextDoneGrpcWrapper(
      tsi_result result, void* user_data, const unsigned char* bytes_to_send,
      size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result);
  // Completion of connector_->check_peer().
  void OnPeerCheckedFn(grpc_error_handle error);
  // Drains args_->read_buffer into handshake_buffer_; returns the byte count.
  size_t MoveReadBufferIntoHandshakeBuffer();
  // Extracts the peer from handshaker_result_ and starts the peer check.
  grpc_error_handle CheckPeerLocked() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // State set at creation time.
  tsi_handshaker* handshaker_;
  RefCountedPtr<grpc_security_connector> connector_;

  Mutex mu_;

  // Set once the handshake has failed, succeeded, or been shut down.
  bool is_shutdown_ = false;

  // State saved while performing the handshake.
  HandshakerArgs* args_ = nullptr;
  absl::AnyInvocable<void(absl::Status)> on_handshake_done_;

  // Must stay declared before handshake_buffer_: the constructor's
  // initializer for handshake_buffer_ reads handshake_buffer_size_.
  size_t handshake_buffer_size_;
  // Scratch copy of bytes received from the peer, passed to TSI.
  unsigned char* handshake_buffer_;
  // Bytes currently being written to the peer.
  SliceBuffer outgoing_;
  RefCountedPtr<grpc_auth_context> auth_context_;
  tsi_handshaker_result* handshaker_result_ = nullptr;
  // 0 means no preference; nullptr is then passed to the protector factories.
  size_t max_frame_size_ = 0;
  // Error detail filled in by tsi_handshaker_next().
  std::string tsi_handshake_error_;
  grpc_closure* on_peer_checked_ ABSL_GUARDED_BY(mu_) = nullptr;
};
131
SecurityHandshaker(tsi_handshaker * handshaker,grpc_security_connector * connector,const ChannelArgs & args)132 SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker,
133 grpc_security_connector* connector,
134 const ChannelArgs& args)
135 : handshaker_(handshaker),
136 connector_(connector->Ref(DEBUG_LOCATION, "handshake")),
137 handshake_buffer_size_(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE),
138 handshake_buffer_(
139 static_cast<uint8_t*>(gpr_malloc(handshake_buffer_size_))),
140 max_frame_size_(
141 std::max(0, args.GetInt(GRPC_ARG_TSI_MAX_FRAME_SIZE).value_or(0))) {}
142
// Releases everything the handshaker owns: the TSI handshaker, any handshaker
// result not yet consumed (handshaker_result_ may be nullptr here, e.g. after
// OnPeerCheckedFn() consumed it), the scratch buffer, and the refs on the
// auth context and security connector.
SecurityHandshaker::~SecurityHandshaker() {
  tsi_handshaker_destroy(handshaker_);
  tsi_handshaker_result_destroy(handshaker_result_);
  gpr_free(handshake_buffer_);
  auth_context_.reset(DEBUG_LOCATION, "handshake");
  connector_.reset(DEBUG_LOCATION, "handshake");
}
150
MoveReadBufferIntoHandshakeBuffer()151 size_t SecurityHandshaker::MoveReadBufferIntoHandshakeBuffer() {
152 size_t bytes_in_read_buffer = args_->read_buffer.Length();
153 if (handshake_buffer_size_ < bytes_in_read_buffer) {
154 handshake_buffer_ = static_cast<uint8_t*>(
155 gpr_realloc(handshake_buffer_, bytes_in_read_buffer));
156 handshake_buffer_size_ = bytes_in_read_buffer;
157 }
158 size_t offset = 0;
159 while (args_->read_buffer.Count() > 0) {
160 Slice slice = args_->read_buffer.TakeFirst();
161 memcpy(handshake_buffer_ + offset, slice.data(), slice.size());
162 offset += slice.size();
163 }
164 return bytes_in_read_buffer;
165 }
166
167 // If the handshake failed or we're shutting down, clean up and invoke the
168 // callback with the error.
HandshakeFailedLocked(absl::Status error)169 void SecurityHandshaker::HandshakeFailedLocked(absl::Status error) {
170 if (error.ok()) {
171 // If we were shut down after the handshake succeeded but before an
172 // endpoint callback was invoked, we need to generate our own error.
173 error = GRPC_ERROR_CREATE("Handshaker shutdown");
174 }
175 if (!is_shutdown_) {
176 tsi_handshaker_shutdown(handshaker_);
177 // Set shutdown to true so that subsequent calls to
178 // security_handshaker_shutdown() do nothing.
179 is_shutdown_ = true;
180 }
181 // Invoke callback.
182 Finish(std::move(error));
183 }
184
Finish(absl::Status status)185 void SecurityHandshaker::Finish(absl::Status status) {
186 InvokeOnHandshakeDone(args_, std::move(on_handshake_done_),
187 std::move(status));
188 }
189
190 namespace {
191
192 RefCountedPtr<channelz::SocketNode::Security>
MakeChannelzSecurityFromAuthContext(grpc_auth_context * auth_context)193 MakeChannelzSecurityFromAuthContext(grpc_auth_context* auth_context) {
194 RefCountedPtr<channelz::SocketNode::Security> security =
195 MakeRefCounted<channelz::SocketNode::Security>();
196 // TODO(yashykt): Currently, we are assuming TLS by default and are only able
197 // to fill in the remote certificate but we should ideally be able to fill in
198 // other fields in
199 // https://github.com/grpc/grpc/blob/fcd43e90304862a823316b224ee733d17a8cfd90/src/proto/grpc/channelz/channelz.proto#L326
200 // from grpc_auth_context.
201 security->type = channelz::SocketNode::Security::ModelType::kTls;
202 security->tls = absl::make_optional<channelz::SocketNode::Security::Tls>();
203 grpc_auth_property_iterator it = grpc_auth_context_find_properties_by_name(
204 auth_context, GRPC_X509_PEM_CERT_PROPERTY_NAME);
205 const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it);
206 if (prop != nullptr) {
207 security->tls->remote_certificate =
208 std::string(prop->value, prop->value_length);
209 }
210 return security;
211 }
212
213 } // namespace
214
OnPeerCheckedFn(grpc_error_handle error)215 void SecurityHandshaker::OnPeerCheckedFn(grpc_error_handle error) {
216 MutexLock lock(&mu_);
217 on_peer_checked_ = nullptr;
218 if (!error.ok() || is_shutdown_) {
219 HandshakeFailedLocked(error);
220 return;
221 }
222 // Get unused bytes.
223 const unsigned char* unused_bytes = nullptr;
224 size_t unused_bytes_size = 0;
225 tsi_result result = tsi_handshaker_result_get_unused_bytes(
226 handshaker_result_, &unused_bytes, &unused_bytes_size);
227 if (result != TSI_OK) {
228 HandshakeFailedLocked(GRPC_ERROR_CREATE(
229 absl::StrCat("TSI handshaker result does not provide unused bytes (",
230 tsi_result_to_string(result), ")")));
231 return;
232 }
233 // Check whether we need to wrap the endpoint.
234 tsi_frame_protector_type frame_protector_type;
235 result = tsi_handshaker_result_get_frame_protector_type(
236 handshaker_result_, &frame_protector_type);
237 if (result != TSI_OK) {
238 HandshakeFailedLocked(GRPC_ERROR_CREATE(
239 absl::StrCat("TSI handshaker result does not implement "
240 "get_frame_protector_type (",
241 tsi_result_to_string(result), ")")));
242 return;
243 }
244 tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr;
245 tsi_frame_protector* protector = nullptr;
246 switch (frame_protector_type) {
247 case TSI_FRAME_PROTECTOR_ZERO_COPY:
248 ABSL_FALLTHROUGH_INTENDED;
249 case TSI_FRAME_PROTECTOR_NORMAL_OR_ZERO_COPY:
250 // Create zero-copy frame protector.
251 result = tsi_handshaker_result_create_zero_copy_grpc_protector(
252 handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_,
253 &zero_copy_protector);
254 if (result != TSI_OK) {
255 HandshakeFailedLocked(GRPC_ERROR_CREATE(
256 absl::StrCat("Zero-copy frame protector creation failed (",
257 tsi_result_to_string(result), ")")));
258 return;
259 }
260 break;
261 case TSI_FRAME_PROTECTOR_NORMAL:
262 // Create normal frame protector.
263 result = tsi_handshaker_result_create_frame_protector(
264 handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_,
265 &protector);
266 if (result != TSI_OK) {
267 HandshakeFailedLocked(
268 GRPC_ERROR_CREATE(absl::StrCat("Frame protector creation failed (",
269 tsi_result_to_string(result), ")")));
270 return;
271 }
272 break;
273 case TSI_FRAME_PROTECTOR_NONE:
274 break;
275 }
276 bool has_frame_protector =
277 zero_copy_protector != nullptr || protector != nullptr;
278 // If we have a frame protector, create a secure endpoint.
279 if (has_frame_protector) {
280 if (unused_bytes_size > 0) {
281 grpc_slice slice = grpc_slice_from_copied_buffer(
282 reinterpret_cast<const char*>(unused_bytes), unused_bytes_size);
283 args_->endpoint = grpc_secure_endpoint_create(
284 protector, zero_copy_protector, std::move(args_->endpoint), &slice,
285 args_->args.ToC().get(), 1);
286 CSliceUnref(slice);
287 } else {
288 args_->endpoint = grpc_secure_endpoint_create(
289 protector, zero_copy_protector, std::move(args_->endpoint), nullptr,
290 args_->args.ToC().get(), 0);
291 }
292 } else if (unused_bytes_size > 0) {
293 // Not wrapping the endpoint, so just pass along unused bytes.
294 args_->read_buffer.Append(Slice::FromCopiedBuffer(
295 reinterpret_cast<const char*>(unused_bytes), unused_bytes_size));
296 }
297 // Done with handshaker result.
298 tsi_handshaker_result_destroy(handshaker_result_);
299 handshaker_result_ = nullptr;
300 args_->args = args_->args.SetObject(auth_context_);
301 // Add channelz channel args only if frame protector is created.
302 if (has_frame_protector) {
303 args_->args = args_->args.SetObject(
304 MakeChannelzSecurityFromAuthContext(auth_context_.get()));
305 }
306 // Set shutdown to true so that subsequent calls to
307 // security_handshaker_shutdown() do nothing.
308 is_shutdown_ = true;
309 // Invoke callback.
310 Finish(absl::OkStatus());
311 }
312
CheckPeerLocked()313 grpc_error_handle SecurityHandshaker::CheckPeerLocked() {
314 tsi_peer peer;
315 tsi_result result =
316 tsi_handshaker_result_extract_peer(handshaker_result_, &peer);
317 if (result != TSI_OK) {
318 return GRPC_ERROR_CREATE(absl::StrCat("Peer extraction failed (",
319 tsi_result_to_string(result), ")"));
320 }
321 on_peer_checked_ = NewClosure(
322 [self = RefAsSubclass<SecurityHandshaker>()](absl::Status status) {
323 self->OnPeerCheckedFn(std::move(status));
324 });
325 connector_->check_peer(peer, args_->endpoint.get(), args_->args,
326 &auth_context_, on_peer_checked_);
327 grpc_auth_property_iterator it = grpc_auth_context_find_properties_by_name(
328 auth_context_.get(), GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME);
329 const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it);
330 if (!prop ||
331 !strcmp(tsi_security_level_to_string(TSI_SECURITY_NONE), prop->value)) {
332 global_stats().IncrementInsecureConnectionsCreated();
333 }
334 return absl::OkStatus();
335 }
336
OnHandshakeNextDoneLocked(tsi_result result,const unsigned char * bytes_to_send,size_t bytes_to_send_size,tsi_handshaker_result * handshaker_result)337 grpc_error_handle SecurityHandshaker::OnHandshakeNextDoneLocked(
338 tsi_result result, const unsigned char* bytes_to_send,
339 size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) {
340 grpc_error_handle error;
341 // Handshaker was shutdown.
342 if (is_shutdown_) {
343 tsi_handshaker_result_destroy(handshaker_result);
344 return GRPC_ERROR_CREATE("Handshaker shutdown");
345 }
346 // Read more if we need to.
347 if (result == TSI_INCOMPLETE_DATA) {
348 CHECK_EQ(bytes_to_send_size, 0u);
349 grpc_endpoint_read(
350 args_->endpoint.get(), args_->read_buffer.c_slice_buffer(),
351 NewClosure([self = RefAsSubclass<SecurityHandshaker>()](
352 absl::Status status) {
353 self->OnHandshakeDataReceivedFromPeerFnScheduler(std::move(status));
354 }),
355 /*urgent=*/true, /*min_progress_size=*/1);
356 return error;
357 }
358 if (result != TSI_OK) {
359 // TODO(roth): Get a better signal from the TSI layer as to what
360 // status code we should use here.
361 return GRPC_ERROR_CREATE(absl::StrCat(
362 connector_->type().name(), " handshake failed (",
363 tsi_result_to_string(result), ")",
364 (tsi_handshake_error_.empty() ? "" : ": "), tsi_handshake_error_));
365 }
366 // Update handshaker result.
367 if (handshaker_result != nullptr) {
368 CHECK_EQ(handshaker_result_, nullptr);
369 handshaker_result_ = handshaker_result;
370 }
371 if (bytes_to_send_size > 0) {
372 // Send data to peer, if needed.
373 outgoing_.Clear();
374 outgoing_.Append(Slice::FromCopiedBuffer(
375 reinterpret_cast<const char*>(bytes_to_send), bytes_to_send_size));
376 grpc_endpoint_write(
377 args_->endpoint.get(), outgoing_.c_slice_buffer(),
378 NewClosure(
379 [self = RefAsSubclass<SecurityHandshaker>()](absl::Status status) {
380 self->OnHandshakeDataSentToPeerFnScheduler(std::move(status));
381 }),
382 nullptr, /*max_frame_size=*/INT_MAX);
383 } else if (handshaker_result == nullptr) {
384 // There is nothing to send, but need to read from peer.
385 grpc_endpoint_read(
386 args_->endpoint.get(), args_->read_buffer.c_slice_buffer(),
387 NewClosure([self = RefAsSubclass<SecurityHandshaker>()](
388 absl::Status status) {
389 self->OnHandshakeDataReceivedFromPeerFnScheduler(std::move(status));
390 }),
391 /*urgent=*/true, /*min_progress_size=*/1);
392 } else {
393 // Handshake has finished, check peer and so on.
394 error = CheckPeerLocked();
395 }
396 return error;
397 }
398
OnHandshakeNextDoneGrpcWrapper(tsi_result result,void * user_data,const unsigned char * bytes_to_send,size_t bytes_to_send_size,tsi_handshaker_result * handshaker_result)399 void SecurityHandshaker::OnHandshakeNextDoneGrpcWrapper(
400 tsi_result result, void* user_data, const unsigned char* bytes_to_send,
401 size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) {
402 RefCountedPtr<SecurityHandshaker> h(
403 static_cast<SecurityHandshaker*>(user_data));
404 MutexLock lock(&h->mu_);
405 grpc_error_handle error = h->OnHandshakeNextDoneLocked(
406 result, bytes_to_send, bytes_to_send_size, handshaker_result);
407 if (!error.ok()) {
408 h->HandshakeFailedLocked(std::move(error));
409 }
410 }
411
DoHandshakerNextLocked(const unsigned char * bytes_received,size_t bytes_received_size)412 grpc_error_handle SecurityHandshaker::DoHandshakerNextLocked(
413 const unsigned char* bytes_received, size_t bytes_received_size) {
414 // Invoke TSI handshaker.
415 const unsigned char* bytes_to_send = nullptr;
416 size_t bytes_to_send_size = 0;
417 tsi_handshaker_result* hs_result = nullptr;
418 auto self = RefAsSubclass<SecurityHandshaker>();
419 tsi_result result = tsi_handshaker_next(
420 handshaker_, bytes_received, bytes_received_size, &bytes_to_send,
421 &bytes_to_send_size, &hs_result, &OnHandshakeNextDoneGrpcWrapper,
422 self.get(), &tsi_handshake_error_);
423 if (result == TSI_ASYNC) {
424 // Handshaker operating asynchronously. Callback will be invoked in a TSI
425 // thread. We no longer own the ref held in self.
426 self.release();
427 return absl::OkStatus();
428 }
429 // Handshaker returned synchronously. Invoke callback directly in
430 // this thread with our existing exec_ctx.
431 return OnHandshakeNextDoneLocked(result, bytes_to_send, bytes_to_send_size,
432 hs_result);
433 }
434
435 // This callback might be run inline while we are still holding on to the mutex,
436 // so run OnHandshakeDataReceivedFromPeerFn asynchronously to avoid a deadlock.
437 // TODO(roth): This will no longer be necessary once we migrate to the
438 // EventEngine endpoint API.
OnHandshakeDataReceivedFromPeerFnScheduler(grpc_error_handle error)439 void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler(
440 grpc_error_handle error) {
441 args_->event_engine->Run([self = RefAsSubclass<SecurityHandshaker>(),
442 error = std::move(error)]() mutable {
443 ApplicationCallbackExecCtx callback_exec_ctx;
444 ExecCtx exec_ctx;
445 self->OnHandshakeDataReceivedFromPeerFn(std::move(error));
446 // Avoid destruction outside of an ExecCtx (since this is non-cancelable).
447 self.reset();
448 });
449 }
450
OnHandshakeDataReceivedFromPeerFn(absl::Status error)451 void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn(absl::Status error) {
452 MutexLock lock(&mu_);
453 if (!error.ok() || is_shutdown_) {
454 HandshakeFailedLocked(
455 GRPC_ERROR_CREATE_REFERENCING("Handshake read failed", &error, 1));
456 return;
457 }
458 // Copy all slices received.
459 size_t bytes_received_size = MoveReadBufferIntoHandshakeBuffer();
460 // Call TSI handshaker.
461 error = DoHandshakerNextLocked(handshake_buffer_, bytes_received_size);
462 if (!error.ok()) {
463 HandshakeFailedLocked(std::move(error));
464 }
465 }
466
467 // This callback might be run inline while we are still holding on to the mutex,
468 // so run OnHandshakeDataSentToPeerFn asynchronously to avoid a deadlock.
469 // TODO(roth): This will no longer be necessary once we migrate to the
470 // EventEngine endpoint API.
OnHandshakeDataSentToPeerFnScheduler(grpc_error_handle error)471 void SecurityHandshaker::OnHandshakeDataSentToPeerFnScheduler(
472 grpc_error_handle error) {
473 args_->event_engine->Run([self = RefAsSubclass<SecurityHandshaker>(),
474 error = std::move(error)]() mutable {
475 ApplicationCallbackExecCtx callback_exec_ctx;
476 ExecCtx exec_ctx;
477 self->OnHandshakeDataSentToPeerFn(std::move(error));
478 // Avoid destruction outside of an ExecCtx (since this is non-cancelable).
479 self.reset();
480 });
481 }
482
OnHandshakeDataSentToPeerFn(absl::Status error)483 void SecurityHandshaker::OnHandshakeDataSentToPeerFn(absl::Status error) {
484 MutexLock lock(&mu_);
485 if (!error.ok() || is_shutdown_) {
486 HandshakeFailedLocked(
487 GRPC_ERROR_CREATE_REFERENCING("Handshake write failed", &error, 1));
488 return;
489 }
490 // We may be done.
491 if (handshaker_result_ == nullptr) {
492 grpc_endpoint_read(
493 args_->endpoint.get(), args_->read_buffer.c_slice_buffer(),
494 NewClosure([self = RefAsSubclass<SecurityHandshaker>()](
495 absl::Status status) {
496 self->OnHandshakeDataReceivedFromPeerFnScheduler(std::move(status));
497 }),
498 /*urgent=*/true, /*min_progress_size=*/1);
499 } else {
500 error = CheckPeerLocked();
501 if (!error.ok()) {
502 HandshakeFailedLocked(error);
503 return;
504 }
505 }
506 }
507
508 //
509 // public handshaker API
510 //
511
Shutdown(grpc_error_handle error)512 void SecurityHandshaker::Shutdown(grpc_error_handle error) {
513 MutexLock lock(&mu_);
514 if (!is_shutdown_) {
515 is_shutdown_ = true;
516 connector_->cancel_check_peer(on_peer_checked_, std::move(error));
517 tsi_handshaker_shutdown(handshaker_);
518 args_->endpoint.reset();
519 }
520 }
521
DoHandshake(HandshakerArgs * args,absl::AnyInvocable<void (absl::Status)> on_handshake_done)522 void SecurityHandshaker::DoHandshake(
523 HandshakerArgs* args,
524 absl::AnyInvocable<void(absl::Status)> on_handshake_done) {
525 MutexLock lock(&mu_);
526 args_ = args;
527 on_handshake_done_ = std::move(on_handshake_done);
528 size_t bytes_received_size = MoveReadBufferIntoHandshakeBuffer();
529 grpc_error_handle error =
530 DoHandshakerNextLocked(handshake_buffer_, bytes_received_size);
531 if (!error.ok()) {
532 HandshakeFailedLocked(error);
533 }
534 }
535
536 //
537 // FailHandshaker
538 //
539
540 class FailHandshaker : public Handshaker {
541 public:
FailHandshaker(absl::Status status)542 explicit FailHandshaker(absl::Status status) : status_(std::move(status)) {}
name() const543 absl::string_view name() const override { return "security_fail"; }
DoHandshake(HandshakerArgs * args,absl::AnyInvocable<void (absl::Status)> on_handshake_done)544 void DoHandshake(
545 HandshakerArgs* args,
546 absl::AnyInvocable<void(absl::Status)> on_handshake_done) override {
547 InvokeOnHandshakeDone(args, std::move(on_handshake_done), status_);
548 }
Shutdown(absl::Status)549 void Shutdown(absl::Status /*error*/) override {}
550
551 private:
552 ~FailHandshaker() override = default;
553 absl::Status status_;
554 };
555
556 //
557 // handshaker factories
558 //
559
560 class ClientSecurityHandshakerFactory : public HandshakerFactory {
561 public:
AddHandshakers(const ChannelArgs & args,grpc_pollset_set * interested_parties,HandshakeManager * handshake_mgr)562 void AddHandshakers(const ChannelArgs& args,
563 grpc_pollset_set* interested_parties,
564 HandshakeManager* handshake_mgr) override {
565 auto* security_connector =
566 args.GetObject<grpc_channel_security_connector>();
567 if (security_connector) {
568 security_connector->add_handshakers(args, interested_parties,
569 handshake_mgr);
570 }
571 }
Priority()572 HandshakerPriority Priority() override {
573 return HandshakerPriority::kSecurityHandshakers;
574 }
575 ~ClientSecurityHandshakerFactory() override = default;
576 };
577
578 class ServerSecurityHandshakerFactory : public HandshakerFactory {
579 public:
AddHandshakers(const ChannelArgs & args,grpc_pollset_set * interested_parties,HandshakeManager * handshake_mgr)580 void AddHandshakers(const ChannelArgs& args,
581 grpc_pollset_set* interested_parties,
582 HandshakeManager* handshake_mgr) override {
583 auto* security_connector = args.GetObject<grpc_server_security_connector>();
584 if (security_connector) {
585 security_connector->add_handshakers(args, interested_parties,
586 handshake_mgr);
587 }
588 }
Priority()589 HandshakerPriority Priority() override {
590 return HandshakerPriority::kSecurityHandshakers;
591 }
592 ~ServerSecurityHandshakerFactory() override = default;
593 };
594
595 } // namespace
596
597 //
598 // exported functions
599 //
600
SecurityHandshakerCreate(absl::StatusOr<tsi_handshaker * > handshaker,grpc_security_connector * connector,const ChannelArgs & args)601 RefCountedPtr<Handshaker> SecurityHandshakerCreate(
602 absl::StatusOr<tsi_handshaker*> handshaker,
603 grpc_security_connector* connector, const ChannelArgs& args) {
604 // If no TSI handshaker was created, return a handshaker that always fails.
605 // Otherwise, return a real security handshaker.
606 if (!handshaker.ok()) {
607 return MakeRefCounted<FailHandshaker>(
608 absl::Status(handshaker.status().code(),
609 absl::StrCat("Failed to create security handshaker: ",
610 handshaker.status().message())));
611 } else if (*handshaker == nullptr) {
612 // TODO(gtcooke94) Once all TSI impls are updated to pass StatusOr<> instead
613 // of null, we should change this to use absl::InternalError().
614 return MakeRefCounted<FailHandshaker>(
615 absl::UnknownError("Failed to create security handshaker."));
616 } else {
617 return MakeRefCounted<SecurityHandshaker>(*handshaker, connector, args);
618 }
619 }
620
SecurityRegisterHandshakerFactories(CoreConfiguration::Builder * builder)621 void SecurityRegisterHandshakerFactories(CoreConfiguration::Builder* builder) {
622 builder->handshaker_registry()->RegisterHandshakerFactory(
623 HANDSHAKER_CLIENT, std::make_unique<ClientSecurityHandshakerFactory>());
624 builder->handshaker_registry()->RegisterHandshakerFactory(
625 HANDSHAKER_SERVER, std::make_unique<ServerSecurityHandshakerFactory>());
626 }
627
628 } // namespace grpc_core
629