• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 //
3 // Copyright 2015 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include "src/core/handshaker/security/security_handshaker.h"
20 
21 #include <grpc/grpc_security.h>
22 #include <grpc/grpc_security_constants.h>
23 #include <grpc/impl/channel_arg_names.h>
24 #include <grpc/slice.h>
25 #include <grpc/slice_buffer.h>
26 #include <grpc/support/alloc.h>
27 #include <grpc/support/port_platform.h>
28 #include <limits.h>
29 #include <stdint.h>
30 #include <string.h>
31 
32 #include <algorithm>
33 #include <memory>
34 #include <string>
35 #include <utility>
36 
37 #include "absl/base/attributes.h"
38 #include "absl/functional/any_invocable.h"
39 #include "absl/log/check.h"
40 #include "absl/status/status.h"
41 #include "absl/strings/str_cat.h"
42 #include "absl/strings/string_view.h"
43 #include "absl/types/optional.h"
44 #include "src/core/channelz/channelz.h"
45 #include "src/core/config/core_configuration.h"
46 #include "src/core/handshaker/handshaker.h"
47 #include "src/core/handshaker/handshaker_factory.h"
48 #include "src/core/handshaker/handshaker_registry.h"
49 #include "src/core/handshaker/security/secure_endpoint.h"
50 #include "src/core/lib/channel/channel_args.h"
51 #include "src/core/lib/iomgr/closure.h"
52 #include "src/core/lib/iomgr/endpoint.h"
53 #include "src/core/lib/iomgr/error.h"
54 #include "src/core/lib/iomgr/exec_ctx.h"
55 #include "src/core/lib/iomgr/iomgr_fwd.h"
56 #include "src/core/lib/iomgr/tcp_server.h"
57 #include "src/core/lib/security/context/security_context.h"
58 #include "src/core/lib/slice/slice.h"
59 #include "src/core/lib/slice/slice_internal.h"
60 #include "src/core/telemetry/stats.h"
61 #include "src/core/telemetry/stats_data.h"
62 #include "src/core/tsi/transport_security_grpc.h"
63 #include "src/core/util/debug_location.h"
64 #include "src/core/util/ref_counted_ptr.h"
65 #include "src/core/util/sync.h"
66 #include "src/core/util/unique_type_name.h"
67 
68 #define GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE 256
69 
70 namespace grpc_core {
71 
72 namespace {
73 
74 class SecurityHandshaker : public Handshaker {
75  public:
76   SecurityHandshaker(tsi_handshaker* handshaker,
77                      grpc_security_connector* connector,
78                      const ChannelArgs& args);
79   ~SecurityHandshaker() override;
name() const80   absl::string_view name() const override { return "security"; }
81   void DoHandshake(
82       HandshakerArgs* args,
83       absl::AnyInvocable<void(absl::Status)> on_handshake_done) override;
84   void Shutdown(absl::Status error) override;
85 
86  private:
87   grpc_error_handle DoHandshakerNextLocked(const unsigned char* bytes_received,
88                                            size_t bytes_received_size)
89       ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
90 
91   grpc_error_handle OnHandshakeNextDoneLocked(
92       tsi_result result, const unsigned char* bytes_to_send,
93       size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result)
94       ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
95   void HandshakeFailedLocked(absl::Status error)
96       ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
97   void Finish(absl::Status status);
98 
99   void OnHandshakeDataReceivedFromPeerFn(absl::Status error);
100   void OnHandshakeDataSentToPeerFn(absl::Status error);
101   void OnHandshakeDataReceivedFromPeerFnScheduler(grpc_error_handle error);
102   void OnHandshakeDataSentToPeerFnScheduler(grpc_error_handle error);
103   static void OnHandshakeNextDoneGrpcWrapper(
104       tsi_result result, void* user_data, const unsigned char* bytes_to_send,
105       size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result);
106   void OnPeerCheckedFn(grpc_error_handle error);
107   size_t MoveReadBufferIntoHandshakeBuffer();
108   grpc_error_handle CheckPeerLocked() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
109 
110   // State set at creation time.
111   tsi_handshaker* handshaker_;
112   RefCountedPtr<grpc_security_connector> connector_;
113 
114   Mutex mu_;
115 
116   bool is_shutdown_ = false;
117 
118   // State saved while performing the handshake.
119   HandshakerArgs* args_ = nullptr;
120   absl::AnyInvocable<void(absl::Status)> on_handshake_done_;
121 
122   size_t handshake_buffer_size_;
123   unsigned char* handshake_buffer_;
124   SliceBuffer outgoing_;
125   RefCountedPtr<grpc_auth_context> auth_context_;
126   tsi_handshaker_result* handshaker_result_ = nullptr;
127   size_t max_frame_size_ = 0;
128   std::string tsi_handshake_error_;
129   grpc_closure* on_peer_checked_ ABSL_GUARDED_BY(mu_) = nullptr;
130 };
131 
SecurityHandshaker(tsi_handshaker * handshaker,grpc_security_connector * connector,const ChannelArgs & args)132 SecurityHandshaker::SecurityHandshaker(tsi_handshaker* handshaker,
133                                        grpc_security_connector* connector,
134                                        const ChannelArgs& args)
135     : handshaker_(handshaker),
136       connector_(connector->Ref(DEBUG_LOCATION, "handshake")),
137       handshake_buffer_size_(GRPC_INITIAL_HANDSHAKE_BUFFER_SIZE),
138       handshake_buffer_(
139           static_cast<uint8_t*>(gpr_malloc(handshake_buffer_size_))),
140       max_frame_size_(
141           std::max(0, args.GetInt(GRPC_ARG_TSI_MAX_FRAME_SIZE).value_or(0))) {}
142 
~SecurityHandshaker()143 SecurityHandshaker::~SecurityHandshaker() {
144   tsi_handshaker_destroy(handshaker_);
145   tsi_handshaker_result_destroy(handshaker_result_);
146   gpr_free(handshake_buffer_);
147   auth_context_.reset(DEBUG_LOCATION, "handshake");
148   connector_.reset(DEBUG_LOCATION, "handshake");
149 }
150 
MoveReadBufferIntoHandshakeBuffer()151 size_t SecurityHandshaker::MoveReadBufferIntoHandshakeBuffer() {
152   size_t bytes_in_read_buffer = args_->read_buffer.Length();
153   if (handshake_buffer_size_ < bytes_in_read_buffer) {
154     handshake_buffer_ = static_cast<uint8_t*>(
155         gpr_realloc(handshake_buffer_, bytes_in_read_buffer));
156     handshake_buffer_size_ = bytes_in_read_buffer;
157   }
158   size_t offset = 0;
159   while (args_->read_buffer.Count() > 0) {
160     Slice slice = args_->read_buffer.TakeFirst();
161     memcpy(handshake_buffer_ + offset, slice.data(), slice.size());
162     offset += slice.size();
163   }
164   return bytes_in_read_buffer;
165 }
166 
167 // If the handshake failed or we're shutting down, clean up and invoke the
168 // callback with the error.
HandshakeFailedLocked(absl::Status error)169 void SecurityHandshaker::HandshakeFailedLocked(absl::Status error) {
170   if (error.ok()) {
171     // If we were shut down after the handshake succeeded but before an
172     // endpoint callback was invoked, we need to generate our own error.
173     error = GRPC_ERROR_CREATE("Handshaker shutdown");
174   }
175   if (!is_shutdown_) {
176     tsi_handshaker_shutdown(handshaker_);
177     // Set shutdown to true so that subsequent calls to
178     // security_handshaker_shutdown() do nothing.
179     is_shutdown_ = true;
180   }
181   // Invoke callback.
182   Finish(std::move(error));
183 }
184 
Finish(absl::Status status)185 void SecurityHandshaker::Finish(absl::Status status) {
186   InvokeOnHandshakeDone(args_, std::move(on_handshake_done_),
187                         std::move(status));
188 }
189 
190 namespace {
191 
192 RefCountedPtr<channelz::SocketNode::Security>
MakeChannelzSecurityFromAuthContext(grpc_auth_context * auth_context)193 MakeChannelzSecurityFromAuthContext(grpc_auth_context* auth_context) {
194   RefCountedPtr<channelz::SocketNode::Security> security =
195       MakeRefCounted<channelz::SocketNode::Security>();
196   // TODO(yashykt): Currently, we are assuming TLS by default and are only able
197   // to fill in the remote certificate but we should ideally be able to fill in
198   // other fields in
199   // https://github.com/grpc/grpc/blob/fcd43e90304862a823316b224ee733d17a8cfd90/src/proto/grpc/channelz/channelz.proto#L326
200   // from grpc_auth_context.
201   security->type = channelz::SocketNode::Security::ModelType::kTls;
202   security->tls = absl::make_optional<channelz::SocketNode::Security::Tls>();
203   grpc_auth_property_iterator it = grpc_auth_context_find_properties_by_name(
204       auth_context, GRPC_X509_PEM_CERT_PROPERTY_NAME);
205   const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it);
206   if (prop != nullptr) {
207     security->tls->remote_certificate =
208         std::string(prop->value, prop->value_length);
209   }
210   return security;
211 }
212 
213 }  // namespace
214 
OnPeerCheckedFn(grpc_error_handle error)215 void SecurityHandshaker::OnPeerCheckedFn(grpc_error_handle error) {
216   MutexLock lock(&mu_);
217   on_peer_checked_ = nullptr;
218   if (!error.ok() || is_shutdown_) {
219     HandshakeFailedLocked(error);
220     return;
221   }
222   // Get unused bytes.
223   const unsigned char* unused_bytes = nullptr;
224   size_t unused_bytes_size = 0;
225   tsi_result result = tsi_handshaker_result_get_unused_bytes(
226       handshaker_result_, &unused_bytes, &unused_bytes_size);
227   if (result != TSI_OK) {
228     HandshakeFailedLocked(GRPC_ERROR_CREATE(
229         absl::StrCat("TSI handshaker result does not provide unused bytes (",
230                      tsi_result_to_string(result), ")")));
231     return;
232   }
233   // Check whether we need to wrap the endpoint.
234   tsi_frame_protector_type frame_protector_type;
235   result = tsi_handshaker_result_get_frame_protector_type(
236       handshaker_result_, &frame_protector_type);
237   if (result != TSI_OK) {
238     HandshakeFailedLocked(GRPC_ERROR_CREATE(
239         absl::StrCat("TSI handshaker result does not implement "
240                      "get_frame_protector_type (",
241                      tsi_result_to_string(result), ")")));
242     return;
243   }
244   tsi_zero_copy_grpc_protector* zero_copy_protector = nullptr;
245   tsi_frame_protector* protector = nullptr;
246   switch (frame_protector_type) {
247     case TSI_FRAME_PROTECTOR_ZERO_COPY:
248       ABSL_FALLTHROUGH_INTENDED;
249     case TSI_FRAME_PROTECTOR_NORMAL_OR_ZERO_COPY:
250       // Create zero-copy frame protector.
251       result = tsi_handshaker_result_create_zero_copy_grpc_protector(
252           handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_,
253           &zero_copy_protector);
254       if (result != TSI_OK) {
255         HandshakeFailedLocked(GRPC_ERROR_CREATE(
256             absl::StrCat("Zero-copy frame protector creation failed (",
257                          tsi_result_to_string(result), ")")));
258         return;
259       }
260       break;
261     case TSI_FRAME_PROTECTOR_NORMAL:
262       // Create normal frame protector.
263       result = tsi_handshaker_result_create_frame_protector(
264           handshaker_result_, max_frame_size_ == 0 ? nullptr : &max_frame_size_,
265           &protector);
266       if (result != TSI_OK) {
267         HandshakeFailedLocked(
268             GRPC_ERROR_CREATE(absl::StrCat("Frame protector creation failed (",
269                                            tsi_result_to_string(result), ")")));
270         return;
271       }
272       break;
273     case TSI_FRAME_PROTECTOR_NONE:
274       break;
275   }
276   bool has_frame_protector =
277       zero_copy_protector != nullptr || protector != nullptr;
278   // If we have a frame protector, create a secure endpoint.
279   if (has_frame_protector) {
280     if (unused_bytes_size > 0) {
281       grpc_slice slice = grpc_slice_from_copied_buffer(
282           reinterpret_cast<const char*>(unused_bytes), unused_bytes_size);
283       args_->endpoint = grpc_secure_endpoint_create(
284           protector, zero_copy_protector, std::move(args_->endpoint), &slice,
285           args_->args.ToC().get(), 1);
286       CSliceUnref(slice);
287     } else {
288       args_->endpoint = grpc_secure_endpoint_create(
289           protector, zero_copy_protector, std::move(args_->endpoint), nullptr,
290           args_->args.ToC().get(), 0);
291     }
292   } else if (unused_bytes_size > 0) {
293     // Not wrapping the endpoint, so just pass along unused bytes.
294     args_->read_buffer.Append(Slice::FromCopiedBuffer(
295         reinterpret_cast<const char*>(unused_bytes), unused_bytes_size));
296   }
297   // Done with handshaker result.
298   tsi_handshaker_result_destroy(handshaker_result_);
299   handshaker_result_ = nullptr;
300   args_->args = args_->args.SetObject(auth_context_);
301   // Add channelz channel args only if frame protector is created.
302   if (has_frame_protector) {
303     args_->args = args_->args.SetObject(
304         MakeChannelzSecurityFromAuthContext(auth_context_.get()));
305   }
306   // Set shutdown to true so that subsequent calls to
307   // security_handshaker_shutdown() do nothing.
308   is_shutdown_ = true;
309   // Invoke callback.
310   Finish(absl::OkStatus());
311 }
312 
CheckPeerLocked()313 grpc_error_handle SecurityHandshaker::CheckPeerLocked() {
314   tsi_peer peer;
315   tsi_result result =
316       tsi_handshaker_result_extract_peer(handshaker_result_, &peer);
317   if (result != TSI_OK) {
318     return GRPC_ERROR_CREATE(absl::StrCat("Peer extraction failed (",
319                                           tsi_result_to_string(result), ")"));
320   }
321   on_peer_checked_ = NewClosure(
322       [self = RefAsSubclass<SecurityHandshaker>()](absl::Status status) {
323         self->OnPeerCheckedFn(std::move(status));
324       });
325   connector_->check_peer(peer, args_->endpoint.get(), args_->args,
326                          &auth_context_, on_peer_checked_);
327   grpc_auth_property_iterator it = grpc_auth_context_find_properties_by_name(
328       auth_context_.get(), GRPC_TRANSPORT_SECURITY_LEVEL_PROPERTY_NAME);
329   const grpc_auth_property* prop = grpc_auth_property_iterator_next(&it);
330   if (!prop ||
331       !strcmp(tsi_security_level_to_string(TSI_SECURITY_NONE), prop->value)) {
332     global_stats().IncrementInsecureConnectionsCreated();
333   }
334   return absl::OkStatus();
335 }
336 
OnHandshakeNextDoneLocked(tsi_result result,const unsigned char * bytes_to_send,size_t bytes_to_send_size,tsi_handshaker_result * handshaker_result)337 grpc_error_handle SecurityHandshaker::OnHandshakeNextDoneLocked(
338     tsi_result result, const unsigned char* bytes_to_send,
339     size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) {
340   grpc_error_handle error;
341   // Handshaker was shutdown.
342   if (is_shutdown_) {
343     tsi_handshaker_result_destroy(handshaker_result);
344     return GRPC_ERROR_CREATE("Handshaker shutdown");
345   }
346   // Read more if we need to.
347   if (result == TSI_INCOMPLETE_DATA) {
348     CHECK_EQ(bytes_to_send_size, 0u);
349     grpc_endpoint_read(
350         args_->endpoint.get(), args_->read_buffer.c_slice_buffer(),
351         NewClosure([self = RefAsSubclass<SecurityHandshaker>()](
352                        absl::Status status) {
353           self->OnHandshakeDataReceivedFromPeerFnScheduler(std::move(status));
354         }),
355         /*urgent=*/true, /*min_progress_size=*/1);
356     return error;
357   }
358   if (result != TSI_OK) {
359     // TODO(roth): Get a better signal from the TSI layer as to what
360     // status code we should use here.
361     return GRPC_ERROR_CREATE(absl::StrCat(
362         connector_->type().name(), " handshake failed (",
363         tsi_result_to_string(result), ")",
364         (tsi_handshake_error_.empty() ? "" : ": "), tsi_handshake_error_));
365   }
366   // Update handshaker result.
367   if (handshaker_result != nullptr) {
368     CHECK_EQ(handshaker_result_, nullptr);
369     handshaker_result_ = handshaker_result;
370   }
371   if (bytes_to_send_size > 0) {
372     // Send data to peer, if needed.
373     outgoing_.Clear();
374     outgoing_.Append(Slice::FromCopiedBuffer(
375         reinterpret_cast<const char*>(bytes_to_send), bytes_to_send_size));
376     grpc_endpoint_write(
377         args_->endpoint.get(), outgoing_.c_slice_buffer(),
378         NewClosure(
379             [self = RefAsSubclass<SecurityHandshaker>()](absl::Status status) {
380               self->OnHandshakeDataSentToPeerFnScheduler(std::move(status));
381             }),
382         nullptr, /*max_frame_size=*/INT_MAX);
383   } else if (handshaker_result == nullptr) {
384     // There is nothing to send, but need to read from peer.
385     grpc_endpoint_read(
386         args_->endpoint.get(), args_->read_buffer.c_slice_buffer(),
387         NewClosure([self = RefAsSubclass<SecurityHandshaker>()](
388                        absl::Status status) {
389           self->OnHandshakeDataReceivedFromPeerFnScheduler(std::move(status));
390         }),
391         /*urgent=*/true, /*min_progress_size=*/1);
392   } else {
393     // Handshake has finished, check peer and so on.
394     error = CheckPeerLocked();
395   }
396   return error;
397 }
398 
OnHandshakeNextDoneGrpcWrapper(tsi_result result,void * user_data,const unsigned char * bytes_to_send,size_t bytes_to_send_size,tsi_handshaker_result * handshaker_result)399 void SecurityHandshaker::OnHandshakeNextDoneGrpcWrapper(
400     tsi_result result, void* user_data, const unsigned char* bytes_to_send,
401     size_t bytes_to_send_size, tsi_handshaker_result* handshaker_result) {
402   RefCountedPtr<SecurityHandshaker> h(
403       static_cast<SecurityHandshaker*>(user_data));
404   MutexLock lock(&h->mu_);
405   grpc_error_handle error = h->OnHandshakeNextDoneLocked(
406       result, bytes_to_send, bytes_to_send_size, handshaker_result);
407   if (!error.ok()) {
408     h->HandshakeFailedLocked(std::move(error));
409   }
410 }
411 
DoHandshakerNextLocked(const unsigned char * bytes_received,size_t bytes_received_size)412 grpc_error_handle SecurityHandshaker::DoHandshakerNextLocked(
413     const unsigned char* bytes_received, size_t bytes_received_size) {
414   // Invoke TSI handshaker.
415   const unsigned char* bytes_to_send = nullptr;
416   size_t bytes_to_send_size = 0;
417   tsi_handshaker_result* hs_result = nullptr;
418   auto self = RefAsSubclass<SecurityHandshaker>();
419   tsi_result result = tsi_handshaker_next(
420       handshaker_, bytes_received, bytes_received_size, &bytes_to_send,
421       &bytes_to_send_size, &hs_result, &OnHandshakeNextDoneGrpcWrapper,
422       self.get(), &tsi_handshake_error_);
423   if (result == TSI_ASYNC) {
424     // Handshaker operating asynchronously. Callback will be invoked in a TSI
425     // thread. We no longer own the ref held in self.
426     self.release();
427     return absl::OkStatus();
428   }
429   // Handshaker returned synchronously. Invoke callback directly in
430   // this thread with our existing exec_ctx.
431   return OnHandshakeNextDoneLocked(result, bytes_to_send, bytes_to_send_size,
432                                    hs_result);
433 }
434 
435 // This callback might be run inline while we are still holding on to the mutex,
436 // so run OnHandshakeDataReceivedFromPeerFn asynchronously to avoid a deadlock.
437 // TODO(roth): This will no longer be necessary once we migrate to the
438 // EventEngine endpoint API.
OnHandshakeDataReceivedFromPeerFnScheduler(grpc_error_handle error)439 void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFnScheduler(
440     grpc_error_handle error) {
441   args_->event_engine->Run([self = RefAsSubclass<SecurityHandshaker>(),
442                             error = std::move(error)]() mutable {
443     ApplicationCallbackExecCtx callback_exec_ctx;
444     ExecCtx exec_ctx;
445     self->OnHandshakeDataReceivedFromPeerFn(std::move(error));
446     // Avoid destruction outside of an ExecCtx (since this is non-cancelable).
447     self.reset();
448   });
449 }
450 
OnHandshakeDataReceivedFromPeerFn(absl::Status error)451 void SecurityHandshaker::OnHandshakeDataReceivedFromPeerFn(absl::Status error) {
452   MutexLock lock(&mu_);
453   if (!error.ok() || is_shutdown_) {
454     HandshakeFailedLocked(
455         GRPC_ERROR_CREATE_REFERENCING("Handshake read failed", &error, 1));
456     return;
457   }
458   // Copy all slices received.
459   size_t bytes_received_size = MoveReadBufferIntoHandshakeBuffer();
460   // Call TSI handshaker.
461   error = DoHandshakerNextLocked(handshake_buffer_, bytes_received_size);
462   if (!error.ok()) {
463     HandshakeFailedLocked(std::move(error));
464   }
465 }
466 
467 // This callback might be run inline while we are still holding on to the mutex,
468 // so run OnHandshakeDataSentToPeerFn asynchronously to avoid a deadlock.
469 // TODO(roth): This will no longer be necessary once we migrate to the
470 // EventEngine endpoint API.
OnHandshakeDataSentToPeerFnScheduler(grpc_error_handle error)471 void SecurityHandshaker::OnHandshakeDataSentToPeerFnScheduler(
472     grpc_error_handle error) {
473   args_->event_engine->Run([self = RefAsSubclass<SecurityHandshaker>(),
474                             error = std::move(error)]() mutable {
475     ApplicationCallbackExecCtx callback_exec_ctx;
476     ExecCtx exec_ctx;
477     self->OnHandshakeDataSentToPeerFn(std::move(error));
478     // Avoid destruction outside of an ExecCtx (since this is non-cancelable).
479     self.reset();
480   });
481 }
482 
OnHandshakeDataSentToPeerFn(absl::Status error)483 void SecurityHandshaker::OnHandshakeDataSentToPeerFn(absl::Status error) {
484   MutexLock lock(&mu_);
485   if (!error.ok() || is_shutdown_) {
486     HandshakeFailedLocked(
487         GRPC_ERROR_CREATE_REFERENCING("Handshake write failed", &error, 1));
488     return;
489   }
490   // We may be done.
491   if (handshaker_result_ == nullptr) {
492     grpc_endpoint_read(
493         args_->endpoint.get(), args_->read_buffer.c_slice_buffer(),
494         NewClosure([self = RefAsSubclass<SecurityHandshaker>()](
495                        absl::Status status) {
496           self->OnHandshakeDataReceivedFromPeerFnScheduler(std::move(status));
497         }),
498         /*urgent=*/true, /*min_progress_size=*/1);
499   } else {
500     error = CheckPeerLocked();
501     if (!error.ok()) {
502       HandshakeFailedLocked(error);
503       return;
504     }
505   }
506 }
507 
508 //
509 // public handshaker API
510 //
511 
Shutdown(grpc_error_handle error)512 void SecurityHandshaker::Shutdown(grpc_error_handle error) {
513   MutexLock lock(&mu_);
514   if (!is_shutdown_) {
515     is_shutdown_ = true;
516     connector_->cancel_check_peer(on_peer_checked_, std::move(error));
517     tsi_handshaker_shutdown(handshaker_);
518     args_->endpoint.reset();
519   }
520 }
521 
DoHandshake(HandshakerArgs * args,absl::AnyInvocable<void (absl::Status)> on_handshake_done)522 void SecurityHandshaker::DoHandshake(
523     HandshakerArgs* args,
524     absl::AnyInvocable<void(absl::Status)> on_handshake_done) {
525   MutexLock lock(&mu_);
526   args_ = args;
527   on_handshake_done_ = std::move(on_handshake_done);
528   size_t bytes_received_size = MoveReadBufferIntoHandshakeBuffer();
529   grpc_error_handle error =
530       DoHandshakerNextLocked(handshake_buffer_, bytes_received_size);
531   if (!error.ok()) {
532     HandshakeFailedLocked(error);
533   }
534 }
535 
536 //
537 // FailHandshaker
538 //
539 
540 class FailHandshaker : public Handshaker {
541  public:
FailHandshaker(absl::Status status)542   explicit FailHandshaker(absl::Status status) : status_(std::move(status)) {}
name() const543   absl::string_view name() const override { return "security_fail"; }
DoHandshake(HandshakerArgs * args,absl::AnyInvocable<void (absl::Status)> on_handshake_done)544   void DoHandshake(
545       HandshakerArgs* args,
546       absl::AnyInvocable<void(absl::Status)> on_handshake_done) override {
547     InvokeOnHandshakeDone(args, std::move(on_handshake_done), status_);
548   }
Shutdown(absl::Status)549   void Shutdown(absl::Status /*error*/) override {}
550 
551  private:
552   ~FailHandshaker() override = default;
553   absl::Status status_;
554 };
555 
556 //
557 // handshaker factories
558 //
559 
560 class ClientSecurityHandshakerFactory : public HandshakerFactory {
561  public:
AddHandshakers(const ChannelArgs & args,grpc_pollset_set * interested_parties,HandshakeManager * handshake_mgr)562   void AddHandshakers(const ChannelArgs& args,
563                       grpc_pollset_set* interested_parties,
564                       HandshakeManager* handshake_mgr) override {
565     auto* security_connector =
566         args.GetObject<grpc_channel_security_connector>();
567     if (security_connector) {
568       security_connector->add_handshakers(args, interested_parties,
569                                           handshake_mgr);
570     }
571   }
Priority()572   HandshakerPriority Priority() override {
573     return HandshakerPriority::kSecurityHandshakers;
574   }
575   ~ClientSecurityHandshakerFactory() override = default;
576 };
577 
578 class ServerSecurityHandshakerFactory : public HandshakerFactory {
579  public:
AddHandshakers(const ChannelArgs & args,grpc_pollset_set * interested_parties,HandshakeManager * handshake_mgr)580   void AddHandshakers(const ChannelArgs& args,
581                       grpc_pollset_set* interested_parties,
582                       HandshakeManager* handshake_mgr) override {
583     auto* security_connector = args.GetObject<grpc_server_security_connector>();
584     if (security_connector) {
585       security_connector->add_handshakers(args, interested_parties,
586                                           handshake_mgr);
587     }
588   }
Priority()589   HandshakerPriority Priority() override {
590     return HandshakerPriority::kSecurityHandshakers;
591   }
592   ~ServerSecurityHandshakerFactory() override = default;
593 };
594 
595 }  // namespace
596 
597 //
598 // exported functions
599 //
600 
SecurityHandshakerCreate(absl::StatusOr<tsi_handshaker * > handshaker,grpc_security_connector * connector,const ChannelArgs & args)601 RefCountedPtr<Handshaker> SecurityHandshakerCreate(
602     absl::StatusOr<tsi_handshaker*> handshaker,
603     grpc_security_connector* connector, const ChannelArgs& args) {
604   // If no TSI handshaker was created, return a handshaker that always fails.
605   // Otherwise, return a real security handshaker.
606   if (!handshaker.ok()) {
607     return MakeRefCounted<FailHandshaker>(
608         absl::Status(handshaker.status().code(),
609                      absl::StrCat("Failed to create security handshaker: ",
610                                   handshaker.status().message())));
611   } else if (*handshaker == nullptr) {
612     // TODO(gtcooke94) Once all TSI impls are updated to pass StatusOr<> instead
613     // of null, we should change this to use absl::InternalError().
614     return MakeRefCounted<FailHandshaker>(
615         absl::UnknownError("Failed to create security handshaker."));
616   } else {
617     return MakeRefCounted<SecurityHandshaker>(*handshaker, connector, args);
618   }
619 }
620 
SecurityRegisterHandshakerFactories(CoreConfiguration::Builder * builder)621 void SecurityRegisterHandshakerFactories(CoreConfiguration::Builder* builder) {
622   builder->handshaker_registry()->RegisterHandshakerFactory(
623       HANDSHAKER_CLIENT, std::make_unique<ClientSecurityHandshakerFactory>());
624   builder->handshaker_registry()->RegisterHandshakerFactory(
625       HANDSHAKER_SERVER, std::make_unique<ServerSecurityHandshakerFactory>());
626 }
627 
628 }  // namespace grpc_core
629