• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 //
3 // Copyright 2015 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include "src/core/tsi/ssl_transport_security.h"
20 
21 #include <grpc/support/port_platform.h>
22 #include <limits.h>
23 #include <string.h>
24 
25 #include <cstdlib>
26 
27 #include "src/core/lib/surface/init.h"
28 #include "src/core/tsi/transport_security_interface.h"
29 
30 // TODO(jboeuf): refactor inet_ntop into a portability header.
31 // Note: for whomever reads this and tries to refactor this, this
32 // can't be in grpc, it has to be in gpr.
33 #ifdef GPR_WINDOWS
34 #include <ws2tcpip.h>
35 #else
36 #include <arpa/inet.h>
37 #include <sys/socket.h>
38 #endif
39 
40 #include <grpc/grpc_crl_provider.h>
41 #include <grpc/grpc_security.h>
42 #include <grpc/support/alloc.h>
43 #include <grpc/support/string_util.h>
44 #include <grpc/support/sync.h>
45 #include <grpc/support/thd_id.h>
46 #include <openssl/bio.h>
47 #include <openssl/crypto.h>  // For OPENSSL_free
48 #include <openssl/engine.h>
49 #include <openssl/err.h>
50 #include <openssl/ssl.h>
51 #include <openssl/tls1.h>
52 #include <openssl/x509.h>
53 #include <openssl/x509v3.h>
54 
55 #include <memory>
56 #include <string>
57 
58 #include "absl/log/check.h"
59 #include "absl/log/log.h"
60 #include "absl/strings/match.h"
61 #include "absl/strings/str_cat.h"
62 #include "absl/strings/string_view.h"
63 #include "src/core/lib/security/credentials/tls/grpc_tls_crl_provider.h"
64 #include "src/core/tsi/ssl/key_logging/ssl_key_logging.h"
65 #include "src/core/tsi/ssl/session_cache/ssl_session_cache.h"
66 #include "src/core/tsi/ssl_transport_security_utils.h"
67 #include "src/core/tsi/ssl_types.h"
68 #include "src/core/tsi/transport_security.h"
69 #include "src/core/util/crash.h"
70 #include "src/core/util/useful.h"
71 
72 // --- Constants. ---
73 
74 #define TSI_SSL_MAX_BIO_WRITE_ATTEMPTS 100
75 #define TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND 16384
76 #define TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND 1024
77 #define TSI_SSL_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE 1024
78 const size_t kMaxChainLength = 100;
79 
80 // Putting a macro like this and littering the source file with #if is really
81 // bad practice.
82 // TODO(jboeuf): refactor all the #if / #endif in a separate module.
83 #ifndef TSI_OPENSSL_ALPN_SUPPORT
84 #define TSI_OPENSSL_ALPN_SUPPORT 1
85 #endif
86 
87 // TODO(jboeuf): I have not found a way to get this number dynamically from the
88 // SSL structure. This is what we would ultimately want though...
89 #define TSI_SSL_MAX_PROTECTION_OVERHEAD 100
90 
91 using TlsSessionKeyLogger = tsi::TlsSessionKeyLoggerCache::TlsSessionKeyLogger;
92 
93 // --- Structure definitions. ---
94 
95 struct tsi_ssl_root_certs_store {
96   X509_STORE* store;
97 };
98 
99 struct tsi_ssl_handshaker_factory {
100   const tsi_ssl_handshaker_factory_vtable* vtable;
101   gpr_refcount refcount;
102 };
103 
104 struct tsi_ssl_client_handshaker_factory {
105   tsi_ssl_handshaker_factory base;
106   SSL_CTX* ssl_context;
107   unsigned char* alpn_protocol_list;
108   size_t alpn_protocol_list_length;
109   grpc_core::RefCountedPtr<tsi::SslSessionLRUCache> session_cache;
110   grpc_core::RefCountedPtr<TlsSessionKeyLogger> key_logger;
111 };
112 
113 struct tsi_ssl_server_handshaker_factory {
114   // Several contexts to support SNI.
115   // The tsi_peer array contains the subject names of the server certificates
116   // associated with the contexts at the same index.
117   tsi_ssl_handshaker_factory base;
118   SSL_CTX** ssl_contexts;
119   tsi_peer* ssl_context_x509_subject_names;
120   size_t ssl_context_count;
121   unsigned char* alpn_protocol_list;
122   size_t alpn_protocol_list_length;
123   grpc_core::RefCountedPtr<TlsSessionKeyLogger> key_logger;
124 };
125 
126 struct tsi_ssl_handshaker {
127   tsi_handshaker base;
128   SSL* ssl;
129   BIO* network_io;
130   tsi_result result;
131   unsigned char* outgoing_bytes_buffer;
132   size_t outgoing_bytes_buffer_size;
133   tsi_ssl_handshaker_factory* factory_ref;
134 };
135 struct tsi_ssl_handshaker_result {
136   tsi_handshaker_result base;
137   SSL* ssl;
138   BIO* network_io;
139   unsigned char* unused_bytes;
140   size_t unused_bytes_size;
141 };
142 struct tsi_ssl_frame_protector {
143   tsi_frame_protector base;
144   SSL* ssl;
145   BIO* network_io;
146   unsigned char* buffer;
147   size_t buffer_size;
148   size_t buffer_offset;
149 };
150 // --- Library Initialization. ---
151 
152 static gpr_once g_init_openssl_once = GPR_ONCE_INIT;
153 static int g_ssl_ctx_ex_factory_index = -1;
154 static int g_ssl_ctx_ex_crl_provider_index = -1;
155 static const unsigned char kSslSessionIdContext[] = {'g', 'r', 'p', 'c'};
156 static int g_ssl_ex_verified_root_cert_index = -1;
157 #if !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE)
158 static const char kSslEnginePrefix[] = "engine:";
159 #endif
160 #if OPENSSL_VERSION_NUMBER >= 0x30000000
161 static const int kSslEcCurveNames[] = {NID_X9_62_prime256v1};
162 #endif
163 
164 #if OPENSSL_VERSION_NUMBER < 0x10100000
165 static gpr_mu* g_openssl_mutexes = nullptr;
166 static void openssl_locking_cb(int mode, int type, const char* file,
167                                int line) GRPC_UNUSED;
168 static unsigned long openssl_thread_id_cb(void) GRPC_UNUSED;
169 
openssl_locking_cb(int mode,int type,const char * file,int line)170 static void openssl_locking_cb(int mode, int type, const char* file, int line) {
171   if (mode & CRYPTO_LOCK) {
172     gpr_mu_lock(&g_openssl_mutexes[type]);
173   } else {
174     gpr_mu_unlock(&g_openssl_mutexes[type]);
175   }
176 }
177 
openssl_thread_id_cb(void)178 static unsigned long openssl_thread_id_cb(void) {
179   return static_cast<unsigned long>(gpr_thd_currentid());
180 }
181 #endif
182 
verified_root_cert_free(void *,void * ptr,CRYPTO_EX_DATA *,int,long,void *)183 static void verified_root_cert_free(void* /*parent*/, void* ptr,
184                                     CRYPTO_EX_DATA* /*ad*/, int /*index*/,
185                                     long /*argl*/, void* /*argp*/) {
186   X509_free(static_cast<X509*>(ptr));
187 }
188 
init_openssl(void)189 static void init_openssl(void) {
190 #if OPENSSL_VERSION_NUMBER >= 0x10100000
191   OPENSSL_init_ssl(0, nullptr);
192   // Ensure OPENSSL global clean up happens after gRPC shutdown completes.
193   // OPENSSL registers an exit handler to clean up global objects, which
194   // otherwise may happen before gRPC removes all references to OPENSSL. Below
195   // exit handler is guaranteed to run after OPENSSL's.
196   std::atexit([]() { grpc_wait_for_shutdown_with_timeout(absl::Seconds(2)); });
197 #else
198   SSL_library_init();
199   SSL_load_error_strings();
200   OpenSSL_add_all_algorithms();
201 #endif
202 #if OPENSSL_VERSION_NUMBER < 0x10100000
203   if (!CRYPTO_get_locking_callback()) {
204     int num_locks = CRYPTO_num_locks();
205     CHECK_GT(num_locks, 0);
206     g_openssl_mutexes = static_cast<gpr_mu*>(
207         gpr_malloc(static_cast<size_t>(num_locks) * sizeof(gpr_mu)));
208     for (int i = 0; i < num_locks; i++) {
209       gpr_mu_init(&g_openssl_mutexes[i]);
210     }
211     CRYPTO_set_locking_callback(openssl_locking_cb);
212     CRYPTO_set_id_callback(openssl_thread_id_cb);
213   } else {
214     GRPC_TRACE_LOG(tsi, INFO) << "OpenSSL callback has already been set.";
215   }
216 #endif
217   g_ssl_ctx_ex_factory_index =
218       SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
219   CHECK_NE(g_ssl_ctx_ex_factory_index, -1);
220 
221   g_ssl_ctx_ex_crl_provider_index =
222       SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
223   CHECK_NE(g_ssl_ctx_ex_crl_provider_index, -1);
224 
225   g_ssl_ex_verified_root_cert_index = SSL_get_ex_new_index(
226       0, nullptr, nullptr, nullptr, verified_root_cert_free);
227   CHECK_NE(g_ssl_ex_verified_root_cert_index, -1);
228 }
229 
230 // --- Ssl utils. ---
231 
232 // TODO(jboeuf): Remove when we are past the debugging phase with this code.
ssl_log_where_info(const SSL * ssl,int where,int flag,const char * msg)233 static void ssl_log_where_info(const SSL* ssl, int where, int flag,
234                                const char* msg) {
235   if ((where & flag) && GRPC_TRACE_FLAG_ENABLED(tsi)) {
236     LOG(INFO) << absl::StrFormat("%20.20s - %s  - %s", msg,
237                                  SSL_state_string_long(ssl),
238                                  SSL_state_string(ssl));
239   }
240 }
241 
242 // Used for debugging. TODO(jboeuf): Remove when code is mature enough.
ssl_info_callback(const SSL * ssl,int where,int ret)243 static void ssl_info_callback(const SSL* ssl, int where, int ret) {
244   if (ret == 0) {
245     LOG(ERROR) << "ssl_info_callback: error occurred.\n";
246     return;
247   }
248 
249   ssl_log_where_info(ssl, where, SSL_CB_LOOP, "LOOP");
250   ssl_log_where_info(ssl, where, SSL_CB_HANDSHAKE_START, "HANDSHAKE START");
251   ssl_log_where_info(ssl, where, SSL_CB_HANDSHAKE_DONE, "HANDSHAKE DONE");
252 }
253 
254 // Returns 1 if name looks like an IP address, 0 otherwise.
255 // This is a very rough heuristic, and only handles IPv6 in hexadecimal form.
looks_like_ip_address(absl::string_view name)256 static int looks_like_ip_address(absl::string_view name) {
257   size_t dot_count = 0;
258   size_t num_size = 0;
259   for (size_t i = 0; i < name.size(); ++i) {
260     if (name[i] == ':') {
261       // IPv6 Address in hexadecimal form, : is not allowed in DNS names.
262       return 1;
263     }
264     if (name[i] >= '0' && name[i] <= '9') {
265       if (num_size > 3) return 0;
266       num_size++;
267     } else if (name[i] == '.') {
268       if (dot_count > 3 || num_size == 0) return 0;
269       dot_count++;
270       num_size = 0;
271     } else {
272       return 0;
273     }
274   }
275   if (dot_count < 3 || num_size == 0) return 0;
276   return 1;
277 }
278 
279 // Gets the subject CN from an X509 cert.
ssl_get_x509_common_name(X509 * cert,unsigned char ** utf8,size_t * utf8_size)280 static tsi_result ssl_get_x509_common_name(X509* cert, unsigned char** utf8,
281                                            size_t* utf8_size) {
282   int common_name_index = -1;
283   X509_NAME_ENTRY* common_name_entry = nullptr;
284   ASN1_STRING* common_name_asn1 = nullptr;
285   X509_NAME* subject_name = X509_get_subject_name(cert);
286   int utf8_returned_size = 0;
287   if (subject_name == nullptr) {
288     VLOG(2) << "Could not get subject name from certificate.";
289     return TSI_NOT_FOUND;
290   }
291   common_name_index =
292       X509_NAME_get_index_by_NID(subject_name, NID_commonName, -1);
293   if (common_name_index == -1) {
294     VLOG(2) << "Could not get common name of subject from certificate.";
295     return TSI_NOT_FOUND;
296   }
297   common_name_entry = X509_NAME_get_entry(subject_name, common_name_index);
298   if (common_name_entry == nullptr) {
299     LOG(ERROR) << "Could not get common name entry from certificate.";
300     return TSI_INTERNAL_ERROR;
301   }
302   common_name_asn1 = X509_NAME_ENTRY_get_data(common_name_entry);
303   if (common_name_asn1 == nullptr) {
304     LOG(ERROR) << "Could not get common name entry asn1 from certificate.";
305     return TSI_INTERNAL_ERROR;
306   }
307   utf8_returned_size = ASN1_STRING_to_UTF8(utf8, common_name_asn1);
308   if (utf8_returned_size < 0) {
309     LOG(ERROR) << "Could not extract utf8 from asn1 string.";
310     return TSI_OUT_OF_RESOURCES;
311   }
312   *utf8_size = static_cast<size_t>(utf8_returned_size);
313   return TSI_OK;
314 }
315 
316 // Gets the subject CN of an X509 cert as a tsi_peer_property.
peer_property_from_x509_common_name(X509 * cert,tsi_peer_property * property)317 static tsi_result peer_property_from_x509_common_name(
318     X509* cert, tsi_peer_property* property) {
319   unsigned char* common_name;
320   size_t common_name_size;
321   tsi_result result =
322       ssl_get_x509_common_name(cert, &common_name, &common_name_size);
323   if (result != TSI_OK) {
324     if (result == TSI_NOT_FOUND) {
325       common_name = nullptr;
326       common_name_size = 0;
327     } else {
328       return result;
329     }
330   }
331   result = tsi_construct_string_peer_property(
332       TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY,
333       common_name == nullptr ? "" : reinterpret_cast<const char*>(common_name),
334       common_name_size, property);
335   OPENSSL_free(common_name);
336   return result;
337 }
338 
339 // Gets the subject of an X509 cert as a tsi_peer_property.
peer_property_from_x509_subject(X509 * cert,tsi_peer_property * property,bool is_verified_root_cert)340 static tsi_result peer_property_from_x509_subject(X509* cert,
341                                                   tsi_peer_property* property,
342                                                   bool is_verified_root_cert) {
343   X509_NAME* subject_name = X509_get_subject_name(cert);
344   if (subject_name == nullptr) {
345     GRPC_TRACE_LOG(tsi, INFO) << "Could not get subject name from certificate.";
346     return TSI_NOT_FOUND;
347   }
348   BIO* bio = BIO_new(BIO_s_mem());
349   X509_NAME_print_ex(bio, subject_name, 0, XN_FLAG_RFC2253);
350   char* contents;
351   long len = BIO_get_mem_data(bio, &contents);
352   if (len < 0) {
353     LOG(ERROR) << "Could not get subject entry from certificate.";
354     BIO_free(bio);
355     return TSI_INTERNAL_ERROR;
356   }
357   tsi_result result;
358   if (!is_verified_root_cert) {
359     result = tsi_construct_string_peer_property(
360         TSI_X509_SUBJECT_PEER_PROPERTY, contents, static_cast<size_t>(len),
361         property);
362   } else {
363     result = tsi_construct_string_peer_property(
364         TSI_X509_VERIFIED_ROOT_CERT_SUBECT_PEER_PROPERTY, contents,
365         static_cast<size_t>(len), property);
366   }
367   BIO_free(bio);
368   return result;
369 }
370 
371 // Gets the X509 cert in PEM format as a tsi_peer_property.
add_pem_certificate(X509 * cert,tsi_peer_property * property)372 static tsi_result add_pem_certificate(X509* cert, tsi_peer_property* property) {
373   BIO* bio = BIO_new(BIO_s_mem());
374   if (!PEM_write_bio_X509(bio, cert)) {
375     BIO_free(bio);
376     return TSI_INTERNAL_ERROR;
377   }
378   char* contents;
379   long len = BIO_get_mem_data(bio, &contents);
380   if (len <= 0) {
381     BIO_free(bio);
382     return TSI_INTERNAL_ERROR;
383   }
384   tsi_result result = tsi_construct_string_peer_property(
385       TSI_X509_PEM_CERT_PROPERTY, contents, static_cast<size_t>(len), property);
386   BIO_free(bio);
387   return result;
388 }
389 
390 // Gets the subject SANs from an X509 cert as a tsi_peer_property.
add_subject_alt_names_properties_to_peer(tsi_peer * peer,GENERAL_NAMES * subject_alt_names,size_t subject_alt_name_count,int * current_insert_index)391 static tsi_result add_subject_alt_names_properties_to_peer(
392     tsi_peer* peer, GENERAL_NAMES* subject_alt_names,
393     size_t subject_alt_name_count, int* current_insert_index) {
394   size_t i;
395   tsi_result result = TSI_OK;
396 
397   for (i = 0; i < subject_alt_name_count; i++) {
398     GENERAL_NAME* subject_alt_name =
399         sk_GENERAL_NAME_value(subject_alt_names, TSI_SIZE_AS_SIZE(i));
400     if (subject_alt_name->type == GEN_DNS ||
401         subject_alt_name->type == GEN_EMAIL ||
402         subject_alt_name->type == GEN_URI) {
403       unsigned char* name = nullptr;
404       int name_size;
405       std::string property_name;
406       if (subject_alt_name->type == GEN_DNS) {
407         name_size = ASN1_STRING_to_UTF8(&name, subject_alt_name->d.dNSName);
408         property_name = TSI_X509_DNS_PEER_PROPERTY;
409       } else if (subject_alt_name->type == GEN_EMAIL) {
410         name_size = ASN1_STRING_to_UTF8(&name, subject_alt_name->d.rfc822Name);
411         property_name = TSI_X509_EMAIL_PEER_PROPERTY;
412       } else {
413         name_size = ASN1_STRING_to_UTF8(
414             &name, subject_alt_name->d.uniformResourceIdentifier);
415         property_name = TSI_X509_URI_PEER_PROPERTY;
416       }
417       if (name_size < 0) {
418         LOG(ERROR) << "Could not get utf8 from asn1 string.";
419         result = TSI_INTERNAL_ERROR;
420         break;
421       }
422       result = tsi_construct_string_peer_property(
423           TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY,
424           reinterpret_cast<const char*>(name), static_cast<size_t>(name_size),
425           &peer->properties[(*current_insert_index)++]);
426       if (result != TSI_OK) {
427         OPENSSL_free(name);
428         break;
429       }
430       result = tsi_construct_string_peer_property(
431           property_name.c_str(), reinterpret_cast<const char*>(name),
432           static_cast<size_t>(name_size),
433           &peer->properties[(*current_insert_index)++]);
434       OPENSSL_free(name);
435     } else if (subject_alt_name->type == GEN_IPADD) {
436       char ntop_buf[INET6_ADDRSTRLEN];
437       int af;
438 
439       if (subject_alt_name->d.iPAddress->length == 4) {
440         af = AF_INET;
441       } else if (subject_alt_name->d.iPAddress->length == 16) {
442         af = AF_INET6;
443       } else {
444         LOG(ERROR) << "SAN IP Address contained invalid IP";
445         result = TSI_INTERNAL_ERROR;
446         break;
447       }
448       const char* name = inet_ntop(af, subject_alt_name->d.iPAddress->data,
449                                    ntop_buf, INET6_ADDRSTRLEN);
450       if (name == nullptr) {
451         LOG(ERROR) << "Could not get IP string from asn1 octet.";
452         result = TSI_INTERNAL_ERROR;
453         break;
454       }
455 
456       result = tsi_construct_string_peer_property_from_cstring(
457           TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, name,
458           &peer->properties[(*current_insert_index)++]);
459       if (result != TSI_OK) break;
460       result = tsi_construct_string_peer_property_from_cstring(
461           TSI_X509_IP_PEER_PROPERTY, name,
462           &peer->properties[(*current_insert_index)++]);
463     } else {
464       result = tsi_construct_string_peer_property_from_cstring(
465           TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, "other types of SAN",
466           &peer->properties[(*current_insert_index)++]);
467     }
468     if (result != TSI_OK) break;
469   }
470   return result;
471 }
472 
473 // Gets information about the peer's X509 cert as a tsi_peer object.
peer_from_x509(X509 * cert,int include_certificate_type,tsi_peer * peer)474 static tsi_result peer_from_x509(X509* cert, int include_certificate_type,
475                                  tsi_peer* peer) {
476   // TODO(jboeuf): Maybe add more properties.
477   GENERAL_NAMES* subject_alt_names = static_cast<GENERAL_NAMES*>(
478       X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
479   int subject_alt_name_count =
480       (subject_alt_names != nullptr)
481           ? static_cast<int>(sk_GENERAL_NAME_num(subject_alt_names))
482           : 0;
483   size_t property_count;
484   tsi_result result;
485   CHECK_GE(subject_alt_name_count, 0);
486   property_count = (include_certificate_type ? size_t{1} : 0) +
487                    3 /* subject, common name, certificate */ +
488                    static_cast<size_t>(subject_alt_name_count);
489   for (int i = 0; i < subject_alt_name_count; i++) {
490     GENERAL_NAME* subject_alt_name =
491         sk_GENERAL_NAME_value(subject_alt_names, TSI_SIZE_AS_SIZE(i));
492     // TODO(zhenlian): Clean up tsi_peer to avoid duplicate entries.
493     // URI, DNS, email and ip address SAN fields are plumbed to tsi_peer, in
494     // addition to all SAN fields (results in duplicate values). This code
495     // snippet updates property_count accordingly.
496     if (subject_alt_name->type == GEN_URI ||
497         subject_alt_name->type == GEN_DNS ||
498         subject_alt_name->type == GEN_EMAIL ||
499         subject_alt_name->type == GEN_IPADD) {
500       property_count += 1;
501     }
502   }
503   result = tsi_construct_peer(property_count, peer);
504   if (result != TSI_OK) return result;
505   int current_insert_index = 0;
506   do {
507     if (include_certificate_type) {
508       result = tsi_construct_string_peer_property_from_cstring(
509           TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_X509_CERTIFICATE_TYPE,
510           &peer->properties[current_insert_index++]);
511       if (result != TSI_OK) break;
512     }
513 
514     result = peer_property_from_x509_subject(
515         cert, &peer->properties[current_insert_index++],
516         /*is_verified_root_cert=*/false);
517     if (result != TSI_OK) break;
518 
519     result = peer_property_from_x509_common_name(
520         cert, &peer->properties[current_insert_index++]);
521     if (result != TSI_OK) break;
522 
523     result =
524         add_pem_certificate(cert, &peer->properties[current_insert_index++]);
525     if (result != TSI_OK) break;
526 
527     if (subject_alt_name_count != 0) {
528       result = add_subject_alt_names_properties_to_peer(
529           peer, subject_alt_names, static_cast<size_t>(subject_alt_name_count),
530           &current_insert_index);
531       if (result != TSI_OK) break;
532     }
533   } while (false);
534 
535   if (subject_alt_names != nullptr) {
536     sk_GENERAL_NAME_pop_free(subject_alt_names, GENERAL_NAME_free);
537   }
538   if (result != TSI_OK) tsi_peer_destruct(peer);
539 
540   CHECK((int)peer->property_count == current_insert_index);
541   return result;
542 }
543 
544 // Loads an in-memory PEM certificate chain into the SSL context.
ssl_ctx_use_certificate_chain(SSL_CTX * context,const char * pem_cert_chain,size_t pem_cert_chain_size)545 static tsi_result ssl_ctx_use_certificate_chain(SSL_CTX* context,
546                                                 const char* pem_cert_chain,
547                                                 size_t pem_cert_chain_size) {
548   tsi_result result = TSI_OK;
549   X509* certificate = nullptr;
550   BIO* pem;
551   CHECK_LE(pem_cert_chain_size, static_cast<size_t>(INT_MAX));
552   pem = BIO_new_mem_buf(pem_cert_chain, static_cast<int>(pem_cert_chain_size));
553   if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
554 
555   do {
556     certificate =
557         PEM_read_bio_X509_AUX(pem, nullptr, nullptr, const_cast<char*>(""));
558     if (certificate == nullptr) {
559       result = TSI_INVALID_ARGUMENT;
560       break;
561     }
562     if (!SSL_CTX_use_certificate(context, certificate)) {
563       result = TSI_INVALID_ARGUMENT;
564       break;
565     }
566     while (true) {
567       X509* certificate_authority =
568           PEM_read_bio_X509(pem, nullptr, nullptr, const_cast<char*>(""));
569       if (certificate_authority == nullptr) {
570         ERR_clear_error();
571         break;  // Done reading.
572       }
573       if (!SSL_CTX_add_extra_chain_cert(context, certificate_authority)) {
574         X509_free(certificate_authority);
575         result = TSI_INVALID_ARGUMENT;
576         break;
577       }
578       // We don't need to free certificate_authority as its ownership has been
579       // transferred to the context. That is not the case for certificate
580       // though.
581       //
582     }
583   } while (false);
584 
585   if (certificate != nullptr) X509_free(certificate);
586   BIO_free(pem);
587   return result;
588 }
589 
590 #if !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE)
ssl_ctx_use_engine_private_key(SSL_CTX * context,const char * pem_key,size_t pem_key_size)591 static tsi_result ssl_ctx_use_engine_private_key(SSL_CTX* context,
592                                                  const char* pem_key,
593                                                  size_t pem_key_size) {
594   tsi_result result = TSI_OK;
595   EVP_PKEY* private_key = nullptr;
596   ENGINE* engine = nullptr;
597   char* engine_name = nullptr;
598   // Parse key which is in following format engine:<engine_id>:<key_id>
599   do {
600     char* engine_start = (char*)pem_key + strlen(kSslEnginePrefix);
601     char* engine_end = (char*)strchr(engine_start, ':');
602     if (engine_end == nullptr) {
603       result = TSI_INVALID_ARGUMENT;
604       break;
605     }
606     char* key_id = engine_end + 1;
607     int engine_name_length = engine_end - engine_start;
608     if (engine_name_length == 0) {
609       result = TSI_INVALID_ARGUMENT;
610       break;
611     }
612     engine_name = static_cast<char*>(gpr_zalloc(engine_name_length + 1));
613     memcpy(engine_name, engine_start, engine_name_length);
614     VLOG(2) << "ENGINE key: " << engine_name;
615     ENGINE_load_dynamic();
616     engine = ENGINE_by_id(engine_name);
617     if (engine == nullptr) {
618       // If not available at ENGINE_DIR, use dynamic to load from
619       // current working directory.
620       engine = ENGINE_by_id("dynamic");
621       if (engine == nullptr) {
622         LOG(ERROR) << "Cannot load dynamic engine";
623         result = TSI_INVALID_ARGUMENT;
624         break;
625       }
626       if (!ENGINE_ctrl_cmd_string(engine, "ID", engine_name, 0) ||
627           !ENGINE_ctrl_cmd_string(engine, "DIR_LOAD", "2", 0) ||
628           !ENGINE_ctrl_cmd_string(engine, "DIR_ADD", ".", 0) ||
629           !ENGINE_ctrl_cmd_string(engine, "LIST_ADD", "1", 0) ||
630           !ENGINE_ctrl_cmd_string(engine, "LOAD", NULL, 0)) {
631         LOG(ERROR) << "Cannot find engine";
632         result = TSI_INVALID_ARGUMENT;
633         break;
634       }
635     }
636     if (!ENGINE_set_default(engine, ENGINE_METHOD_ALL)) {
637       LOG(ERROR) << "ENGINE_set_default with ENGINE_METHOD_ALL failed";
638       result = TSI_INVALID_ARGUMENT;
639       break;
640     }
641     if (!ENGINE_init(engine)) {
642       LOG(ERROR) << "ENGINE_init failed";
643       result = TSI_INVALID_ARGUMENT;
644       break;
645     }
646     private_key = ENGINE_load_private_key(engine, key_id, 0, 0);
647     if (private_key == nullptr) {
648       LOG(ERROR) << "ENGINE_load_private_key failed";
649       result = TSI_INVALID_ARGUMENT;
650       break;
651     }
652     if (!SSL_CTX_use_PrivateKey(context, private_key)) {
653       LOG(ERROR) << "SSL_CTX_use_PrivateKey failed";
654       result = TSI_INVALID_ARGUMENT;
655       break;
656     }
657   } while (0);
658   if (engine != nullptr) ENGINE_free(engine);
659   if (private_key != nullptr) EVP_PKEY_free(private_key);
660   if (engine_name != nullptr) gpr_free(engine_name);
661   return result;
662 }
663 #endif  // !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE)
664 
ssl_ctx_use_pem_private_key(SSL_CTX * context,const char * pem_key,size_t pem_key_size)665 static tsi_result ssl_ctx_use_pem_private_key(SSL_CTX* context,
666                                               const char* pem_key,
667                                               size_t pem_key_size) {
668   tsi_result result = TSI_OK;
669   EVP_PKEY* private_key = nullptr;
670   BIO* pem;
671   CHECK_LE(pem_key_size, static_cast<size_t>(INT_MAX));
672   pem = BIO_new_mem_buf(pem_key, static_cast<int>(pem_key_size));
673   if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
674   do {
675     private_key =
676         PEM_read_bio_PrivateKey(pem, nullptr, nullptr, const_cast<char*>(""));
677     if (private_key == nullptr) {
678       result = TSI_INVALID_ARGUMENT;
679       break;
680     }
681     if (!SSL_CTX_use_PrivateKey(context, private_key)) {
682       result = TSI_INVALID_ARGUMENT;
683       break;
684     }
685   } while (false);
686   if (private_key != nullptr) EVP_PKEY_free(private_key);
687   BIO_free(pem);
688   return result;
689 }
690 
691 // Loads an in-memory PEM private key into the SSL context.
ssl_ctx_use_private_key(SSL_CTX * context,const char * pem_key,size_t pem_key_size)692 static tsi_result ssl_ctx_use_private_key(SSL_CTX* context, const char* pem_key,
693                                           size_t pem_key_size) {
694 // BoringSSL does not have ENGINE support
695 #if !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE)
696   if (strncmp(pem_key, kSslEnginePrefix, strlen(kSslEnginePrefix)) == 0) {
697     return ssl_ctx_use_engine_private_key(context, pem_key, pem_key_size);
698   } else
699 #endif  // !defined(OPENSSL_IS_BORINGSSL) && !defined(OPENSSL_NO_ENGINE)
700   {
701     return ssl_ctx_use_pem_private_key(context, pem_key, pem_key_size);
702   }
703 }
704 
705 // Loads in-memory PEM verification certs into the SSL context and optionally
706 // returns the verification cert names (root_names can be NULL).
x509_store_load_certs(X509_STORE * cert_store,const char * pem_roots,size_t pem_roots_size,STACK_OF (X509_NAME)** root_names)707 static tsi_result x509_store_load_certs(X509_STORE* cert_store,
708                                         const char* pem_roots,
709                                         size_t pem_roots_size,
710                                         STACK_OF(X509_NAME) * *root_names) {
711   tsi_result result = TSI_OK;
712   size_t num_roots = 0;
713   X509* root = nullptr;
714   X509_NAME* root_name = nullptr;
715   BIO* pem;
716   CHECK_LE(pem_roots_size, static_cast<size_t>(INT_MAX));
717   pem = BIO_new_mem_buf(pem_roots, static_cast<int>(pem_roots_size));
718   if (cert_store == nullptr) return TSI_INVALID_ARGUMENT;
719   if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
720   if (root_names != nullptr) {
721     *root_names = sk_X509_NAME_new_null();
722     if (*root_names == nullptr) return TSI_OUT_OF_RESOURCES;
723   }
724 
725   while (true) {
726     root = PEM_read_bio_X509_AUX(pem, nullptr, nullptr, const_cast<char*>(""));
727     if (root == nullptr) {
728       ERR_clear_error();
729       break;  // We're at the end of stream.
730     }
731     if (root_names != nullptr) {
732       root_name = X509_get_subject_name(root);
733       if (root_name == nullptr) {
734         LOG(ERROR) << "Could not get name from root certificate.";
735         result = TSI_INVALID_ARGUMENT;
736         break;
737       }
738       root_name = X509_NAME_dup(root_name);
739       if (root_name == nullptr) {
740         result = TSI_OUT_OF_RESOURCES;
741         break;
742       }
743       sk_X509_NAME_push(*root_names, root_name);
744       root_name = nullptr;
745     }
746     ERR_clear_error();
747     if (!X509_STORE_add_cert(cert_store, root)) {
748       unsigned long error = ERR_get_error();
749       if (ERR_GET_LIB(error) != ERR_LIB_X509 ||
750           ERR_GET_REASON(error) != X509_R_CERT_ALREADY_IN_HASH_TABLE) {
751         LOG(ERROR) << "Could not add root certificate to ssl context.";
752         result = TSI_INTERNAL_ERROR;
753         break;
754       }
755     }
756     X509_free(root);
757     num_roots++;
758   }
759   if (num_roots == 0) {
760     LOG(ERROR) << "Could not load any root certificate.";
761     result = TSI_INVALID_ARGUMENT;
762   }
763 
764   if (result != TSI_OK) {
765     if (root != nullptr) X509_free(root);
766     if (root_names != nullptr) {
767       sk_X509_NAME_pop_free(*root_names, X509_NAME_free);
768       *root_names = nullptr;
769       if (root_name != nullptr) X509_NAME_free(root_name);
770     }
771   }
772   BIO_free(pem);
773   return result;
774 }
775 
ssl_ctx_load_verification_certs(SSL_CTX * context,const char * pem_roots,size_t pem_roots_size,STACK_OF (X509_NAME)** root_name)776 static tsi_result ssl_ctx_load_verification_certs(SSL_CTX* context,
777                                                   const char* pem_roots,
778                                                   size_t pem_roots_size,
779                                                   STACK_OF(X509_NAME) *
780                                                       *root_name) {
781   X509_STORE* cert_store = SSL_CTX_get_cert_store(context);
782   X509_STORE_set_flags(cert_store,
783                        X509_V_FLAG_PARTIAL_CHAIN | X509_V_FLAG_TRUSTED_FIRST);
784   return x509_store_load_certs(cert_store, pem_roots, pem_roots_size,
785                                root_name);
786 }
787 
788 // Populates the SSL context with a private key and a cert chain, and sets the
789 // cipher list and the ephemeral ECDH key.
populate_ssl_context(SSL_CTX * context,const tsi_ssl_pem_key_cert_pair * key_cert_pair,const char * cipher_list)790 static tsi_result populate_ssl_context(
791     SSL_CTX* context, const tsi_ssl_pem_key_cert_pair* key_cert_pair,
792     const char* cipher_list) {
793   tsi_result result = TSI_OK;
794   if (key_cert_pair != nullptr) {
795     if (key_cert_pair->cert_chain != nullptr) {
796       result = ssl_ctx_use_certificate_chain(context, key_cert_pair->cert_chain,
797                                              strlen(key_cert_pair->cert_chain));
798       if (result != TSI_OK) {
799         LOG(ERROR) << "Invalid cert chain file.";
800         return result;
801       }
802     }
803     if (key_cert_pair->private_key != nullptr) {
804       result = ssl_ctx_use_private_key(context, key_cert_pair->private_key,
805                                        strlen(key_cert_pair->private_key));
806       if (result != TSI_OK || !SSL_CTX_check_private_key(context)) {
807         LOG(ERROR) << "Invalid private key.";
808         return result != TSI_OK ? result : TSI_INVALID_ARGUMENT;
809       }
810     }
811   }
812   if ((cipher_list != nullptr) &&
813       !SSL_CTX_set_cipher_list(context, cipher_list)) {
814     LOG(ERROR) << "Invalid cipher list: " << cipher_list;
815     return TSI_INVALID_ARGUMENT;
816   }
817   {
818 #if OPENSSL_VERSION_NUMBER < 0x30000000L
819     EC_KEY* ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
820     if (!SSL_CTX_set_tmp_ecdh(context, ecdh)) {
821       LOG(ERROR) << "Could not set ephemeral ECDH key.";
822       EC_KEY_free(ecdh);
823       return TSI_INTERNAL_ERROR;
824     }
825     SSL_CTX_set_options(context, SSL_OP_SINGLE_ECDH_USE);
826     EC_KEY_free(ecdh);
827 #else
828     if (!SSL_CTX_set1_groups(context, kSslEcCurveNames, 1)) {
829       LOG(ERROR) << "Could not set ephemeral ECDH key.";
830       return TSI_INTERNAL_ERROR;
831     }
832     SSL_CTX_set_options(context, SSL_OP_SINGLE_ECDH_USE);
833 #endif
834   }
835   return TSI_OK;
836 }
837 
838 // Extracts the CN and the SANs from an X509 cert as a peer object.
tsi_ssl_extract_x509_subject_names_from_pem_cert(const char * pem_cert,tsi_peer * peer)839 tsi_result tsi_ssl_extract_x509_subject_names_from_pem_cert(
840     const char* pem_cert, tsi_peer* peer) {
841   tsi_result result = TSI_OK;
842   X509* cert = nullptr;
843   BIO* pem;
844   pem = BIO_new_mem_buf(pem_cert, static_cast<int>(strlen(pem_cert)));
845   if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
846 
847   cert = PEM_read_bio_X509(pem, nullptr, nullptr, const_cast<char*>(""));
848   if (cert == nullptr) {
849     LOG(ERROR) << "Invalid certificate";
850     result = TSI_INVALID_ARGUMENT;
851   } else {
852     result = peer_from_x509(cert, 0, peer);
853   }
854   if (cert != nullptr) X509_free(cert);
855   BIO_free(pem);
856   return result;
857 }
858 
859 // Builds the alpn protocol name list according to rfc 7301.
build_alpn_protocol_name_list(const char ** alpn_protocols,uint16_t num_alpn_protocols,unsigned char ** protocol_name_list,size_t * protocol_name_list_length)860 static tsi_result build_alpn_protocol_name_list(
861     const char** alpn_protocols, uint16_t num_alpn_protocols,
862     unsigned char** protocol_name_list, size_t* protocol_name_list_length) {
863   uint16_t i;
864   unsigned char* current;
865   *protocol_name_list = nullptr;
866   *protocol_name_list_length = 0;
867   if (num_alpn_protocols == 0) return TSI_INVALID_ARGUMENT;
868   for (i = 0; i < num_alpn_protocols; i++) {
869     size_t length =
870         alpn_protocols[i] == nullptr ? 0 : strlen(alpn_protocols[i]);
871     if (length == 0 || length > 255) {
872       LOG(ERROR) << "Invalid protocol name length: " << length;
873       return TSI_INVALID_ARGUMENT;
874     }
875     *protocol_name_list_length += length + 1;
876   }
877   *protocol_name_list =
878       static_cast<unsigned char*>(gpr_malloc(*protocol_name_list_length));
879   if (*protocol_name_list == nullptr) return TSI_OUT_OF_RESOURCES;
880   current = *protocol_name_list;
881   for (i = 0; i < num_alpn_protocols; i++) {
882     size_t length = strlen(alpn_protocols[i]);
883     *(current++) = static_cast<uint8_t>(length);  // max checked above.
884     memcpy(current, alpn_protocols[i], length);
885     current += length;
886   }
887   // Safety check.
888   if ((current < *protocol_name_list) ||
889       (static_cast<uintptr_t>(current - *protocol_name_list) !=
890        *protocol_name_list_length)) {
891     return TSI_INTERNAL_ERROR;
892   }
893   return TSI_OK;
894 }
895 
896 // This callback is invoked when the CRL has been verified and will soft-fail
897 // errors in verification depending on certain error types.
verify_cb(int ok,X509_STORE_CTX * ctx)898 static int verify_cb(int ok, X509_STORE_CTX* ctx) {
899   int cert_error = X509_STORE_CTX_get_error(ctx);
900   if (cert_error == X509_V_ERR_UNABLE_TO_GET_CRL) {
901     GRPC_TRACE_LOG(tsi, INFO)
902         << "Certificate verification failed to find relevant CRL file. "
903            "Ignoring error.";
904     return 1;
905   }
906   if (cert_error != 0) {
907     LOG(ERROR) << "Certificate verify failed with code " << cert_error;
908   }
909   return ok;
910 }
911 
912 // The verification callback is used for clients that don't really care about
913 // the server's certificate, but we need to pull it anyway, in case a higher
914 // layer wants to look at it. In this case the verification may fail, but
915 // we don't really care.
NullVerifyCallback(X509_STORE_CTX *,void *)916 static int NullVerifyCallback(X509_STORE_CTX* /*ctx*/, void* /*arg*/) {
917   return 1;
918 }
919 
RootCertExtractCallback(X509_STORE_CTX * ctx,void *)920 static int RootCertExtractCallback(X509_STORE_CTX* ctx, void* /*arg*/) {
921   int ret = 1;
922   // Verification was successful. Get the verified chain from the X509_STORE_CTX
923   // and put the root on the SSL object so that we have access to it when
924   // populating the tsi_peer. On error extracting the root, we return success
925   // anyway and proceed with the connection, to preserve the behavior of an
926   // older version of this code.
927 #if OPENSSL_VERSION_NUMBER >= 0x10100000
928   STACK_OF(X509)* chain = X509_STORE_CTX_get0_chain(ctx);
929 #else
930   STACK_OF(X509)* chain = X509_STORE_CTX_get_chain(ctx);
931 #endif
932   if (chain == nullptr) {
933     return ret;
934   }
935 
936   // The root cert is the last in the chain
937   size_t chain_length = sk_X509_num(chain);
938   if (chain_length == 0) {
939     return ret;
940   }
941   X509* root_cert = sk_X509_value(chain, chain_length - 1);
942   if (root_cert == nullptr) {
943     return ret;
944   }
945 
946   ERR_clear_error();
947   int ssl_index = SSL_get_ex_data_X509_STORE_CTX_idx();
948   if (ssl_index < 0) {
949     char err_str[256];
950     ERR_error_string_n(ERR_get_error(), err_str, sizeof(err_str));
951     LOG(ERROR) << "error getting the SSL index from the X509_STORE_CTX: "
952                << err_str;
953     return ret;
954   }
955   SSL* ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, ssl_index));
956   if (ssl == nullptr) {
957     return ret;
958   }
959 
960   // Free the old root and save the new one. There should not be an old root,
961   // but if renegotiation is not disabled (required by RFC 9113, Section
962   // 9.2.1), it is possible that this callback run multiple times for a single
963   // connection. gRPC does not always disable renegotiation. See
964   // https://github.com/grpc/grpc/issues/35368
965   X509_free(static_cast<X509*>(
966       SSL_get_ex_data(ssl, g_ssl_ex_verified_root_cert_index)));
967   int success =
968       SSL_set_ex_data(ssl, g_ssl_ex_verified_root_cert_index, root_cert);
969   if (success == 0) {
970     GRPC_TRACE_LOG(tsi, INFO)
971         << "Could not set verified root cert in SSL's ex_data";
972   } else {
973 #if OPENSSL_VERSION_NUMBER >= 0x10100000L
974     X509_up_ref(root_cert);
975 #else
976     CRYPTO_add(&root_cert->references, 1, CRYPTO_LOCK_X509);
977 #endif
978   }
979   return ret;
980 }
981 
GetCrlProvider(X509_STORE_CTX * ctx)982 static grpc_core::experimental::CrlProvider* GetCrlProvider(
983     X509_STORE_CTX* ctx) {
984   ERR_clear_error();
985   int ssl_index = SSL_get_ex_data_X509_STORE_CTX_idx();
986   if (ssl_index < 0) {
987     char err_str[256];
988     ERR_error_string_n(ERR_get_error(), err_str, sizeof(err_str));
989     GRPC_TRACE_LOG(tsi, INFO)
990         << "error getting the SSL index from the X509_STORE_CTX while looking "
991            "up Crl: "
992         << err_str;
993     return nullptr;
994   }
995   SSL* ssl = static_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, ssl_index));
996   if (ssl == nullptr) {
997     GRPC_TRACE_LOG(tsi, INFO)
998         << "error while fetching from CrlProvider. SSL object is null";
999     return nullptr;
1000   }
1001   SSL_CTX* ssl_ctx = SSL_get_SSL_CTX(ssl);
1002   auto* provider = static_cast<grpc_core::experimental::CrlProvider*>(
1003       SSL_CTX_get_ex_data(ssl_ctx, g_ssl_ctx_ex_crl_provider_index));
1004   return provider;
1005 }
1006 
1007 // If a CRL is returned, the caller is the owner of the CRL and must make sure
1008 // it is freed.
GetCrlFromProvider(grpc_core::experimental::CrlProvider * provider,X509 * cert)1009 static absl::StatusOr<X509_CRL*> GetCrlFromProvider(
1010     grpc_core::experimental::CrlProvider* provider, X509* cert) {
1011   if (provider == nullptr) {
1012     return absl::InvalidArgumentError("CrlProvider is null.");
1013   }
1014   absl::StatusOr<std::string> issuer_name = grpc_core::IssuerFromCert(cert);
1015   if (!issuer_name.ok()) {
1016     GRPC_TRACE_LOG(tsi, INFO) << "Could not get certificate issuer name";
1017     return absl::InvalidArgumentError(issuer_name.status().message());
1018   }
1019   absl::StatusOr<std::string> akid = grpc_core::AkidFromCertificate(cert);
1020   std::string akid_to_use;
1021   if (!akid.ok()) {
1022     GRPC_TRACE_LOG(tsi, INFO)
1023         << "Could not get certificate authority key identifier.";
1024   } else {
1025     akid_to_use = *akid;
1026   }
1027 
1028   grpc_core::experimental::CertificateInfoImpl cert_impl(*issuer_name,
1029                                                          akid_to_use);
1030   std::shared_ptr<grpc_core::experimental::Crl> internal_crl =
1031       provider->GetCrl(cert_impl);
1032   // There wasn't a CRL found in the provider. Returning 0 will end up causing
1033   // OpenSSL to return X509_V_ERR_UNABLE_TO_GET_CRL. We then catch that error
1034   // and behave how we want for a missing CRL.
1035   // It is important to treat missing CRLs and empty CRLs differently.
1036   if (internal_crl == nullptr) {
1037     return absl::NotFoundError("Could not find Crl related to certificate.");
1038   }
1039   X509_CRL* crl =
1040       std::static_pointer_cast<grpc_core::experimental::CrlImpl>(internal_crl)
1041           ->crl();
1042 
1043   return X509_CRL_dup(crl);
1044 }
1045 
1046 // Perform the validation checks in RFC5280 6.3.3 to ensure the given CRL is
1047 // valid
1048 // returns true if the Crl is valid, false otherwise
ValidateCrl(X509 * cert,X509 * issuer,X509_CRL * crl)1049 static bool ValidateCrl(X509* cert, X509* issuer, X509_CRL* crl) {
1050   bool valid = true;
1051   // RFC5280 6.3.3
1052   // 6.3.3a we do not support distribution points
1053   // 6.3.3b verify issuer and scope
1054   valid = grpc_core::VerifyCrlCertIssuerNamesMatch(crl, cert);
1055   if (!valid) {
1056     VLOG(2) << "CRL and cert issuer names mismatched.";
1057     return valid;
1058   }
1059   valid = grpc_core::HasCrlSignBit(issuer);
1060   if (!valid) {
1061     VLOG(2) << "CRL issuer not allowed to sign CRLs.";
1062     return valid;
1063   }
1064   // 6.3.3c Not supporting deltas
1065   // 6.3.3d Not supporting reasons masks
1066   // 6.3.3e Not supporting reasons masks
1067   // 6.3.3f We only support direct CRLs so these paths are by definition the
1068   // same.
1069   // 6.3.3g Verify CRL Signature
1070   valid = grpc_core::VerifyCrlSignature(crl, issuer);
1071   if (!valid) {
1072     VLOG(2) << "Crl signature check failed.";
1073   }
1074   return valid;
1075 }
1076 
1077 // Check if a given certificate is revoked
1078 // Returns 1 if the certificate is not revoked, 0 if the certificate is revoked
CheckCertRevocation(grpc_core::experimental::CrlProvider * provider,X509 * cert,X509 * issuer)1079 static int CheckCertRevocation(grpc_core::experimental::CrlProvider* provider,
1080                                X509* cert, X509* issuer) {
1081   auto crl = GetCrlFromProvider(provider, cert);
1082   // Not finding a CRL is a specific behavior. Per RFC5280, not having a CRL to
1083   // check for a given certificate means that we cannot know for certain if the
1084   // status is Revoked or Unrevoked and instead is Undetermined. How a user
1085   // handles an Undetermined CRL is up to them. We use absl::IsNotFound as an
1086   // analogue for not finding the Crl from the provider, thus the certificate in
1087   // question is Undetermined.
1088   if (absl::IsNotFound(crl.status())) {
1089     // TODO(gtcooke94) knob for undetermined being revoked or unrevoked. By
1090     // default, unrevoked.
1091     return 1;
1092   } else if (!crl.ok()) {
1093     // This is an unexpected error, return false
1094     return 0;
1095   }
1096   // Validate the crl
1097   // RFC5280 6.3.3(a-i)
1098   if (!ValidateCrl(cert, issuer, *crl)) {
1099     X509_CRL_free(*crl);
1100     return 0;
1101   }
1102 
1103   // RFC5280 6.3.3j Actually check revocation
1104   // Look for serial number of certificate in CRL  X509_REVOKED* rev =
1105   // nullptr;
1106   X509_REVOKED* rev;
1107   if (X509_CRL_get0_by_cert(*crl, &rev, cert)) {
1108     // cert is revoked
1109     X509_CRL_free(*crl);
1110     return 0;
1111   }
1112   // The certificate is not revoked
1113   // RFC5280k - Not supported
1114   // RFC5280l - Not supported
1115   X509_CRL_free(*crl);
1116   return 1;
1117 }
1118 
1119 // Checks each certificate in the chain for revocation
1120 // returns 0 if any cert in the chain is revoked, 1 otherwise.
CheckChainRevocation(X509_STORE_CTX * ctx,grpc_core::experimental::CrlProvider * provider)1121 static int CheckChainRevocation(
1122     X509_STORE_CTX* ctx, grpc_core::experimental::CrlProvider* provider) {
1123 #if OPENSSL_VERSION_NUMBER >= 0x10100000
1124   STACK_OF(X509)* chain = X509_STORE_CTX_get0_chain(ctx);
1125 #else
1126   STACK_OF(X509)* chain = X509_STORE_CTX_get_chain(ctx);
1127 #endif
1128   if (chain == nullptr) {
1129     return 0;
1130   }
1131   // BoringSSL returns a size_t (unsigned), while OpenSSL returns an int
1132   // (signed). In OpenSSL, a -1 can indicate a problem. By forcing it into a
1133   // size_t, a -1 return will result in the chain_length being a very large
1134   // number, so it will still fail this check because that very large number
1135   // will be >= kMaxChainLength
1136   size_t chain_length = sk_X509_num(chain);
1137   if (chain_length > kMaxChainLength || chain_length == 0) {
1138     return 0;
1139   }
1140   // Loop to < chain_length - 1 because the last cert is the trust anchor/root
1141   // which cannot be revoked
1142   for (size_t i = 0; i < chain_length - 1; i++) {
1143     X509* cert = sk_X509_value(chain, i);
1144     X509* issuer = sk_X509_value(chain, i + 1);
1145     int ret = CheckCertRevocation(provider, cert, issuer);
1146     if (ret != 1) {
1147       return ret;
1148     }
1149   }
1150   return 1;
1151 }
1152 
1153 // The custom verification function to set in OpenSSL using
1154 // X509_set_cert_verify_callback. This calls the standard OpenSSL procedure
1155 // (X509_verify_cert), then also extracts the root certificate in the built
1156 // chain and does revocation checks when a user has configured CrlProviders.
1157 // returns 1 on success, indicating a trusted chain to a root of trust was
1158 // found, 0 if a trusted chain could not be built.
CustomVerificationFunction(X509_STORE_CTX * ctx,void * arg)1159 static int CustomVerificationFunction(X509_STORE_CTX* ctx, void* arg) {
1160   int ret = X509_verify_cert(ctx);
1161   if (ret <= 0) {
1162     VLOG(2) << "Failed to verify cert chain.";
1163     // Verification failed. We shouldn't expect to have a verified chain, so
1164     // there is no need to attempt to extract the root cert from it, check for
1165     // revocation, or check anything else.
1166     return ret;
1167   }
1168   grpc_core::experimental::CrlProvider* provider = GetCrlProvider(ctx);
1169   if (provider != nullptr) {
1170     ret = CheckChainRevocation(ctx, provider);
1171     if (ret <= 0) {
1172       VLOG(2) << "The chain failed revocation checks.";
1173       return ret;
1174     }
1175   }
1176   return RootCertExtractCallback(ctx, arg);
1177 }
1178 
1179 // Sets the min and max TLS version of |ssl_context| to |min_tls_version| and
1180 // |max_tls_version|, respectively. Calling this method is a no-op when using
1181 // OpenSSL versions < 1.1.
tsi_set_min_and_max_tls_versions(SSL_CTX * ssl_context,tsi_tls_version min_tls_version,tsi_tls_version max_tls_version)1182 static tsi_result tsi_set_min_and_max_tls_versions(
1183     SSL_CTX* ssl_context, tsi_tls_version min_tls_version,
1184     tsi_tls_version max_tls_version) {
1185   if (ssl_context == nullptr) {
1186     GRPC_TRACE_LOG(tsi, INFO) << "Invalid nullptr argument to "
1187                                  "|tsi_set_min_and_max_tls_versions|.";
1188     return TSI_INVALID_ARGUMENT;
1189   }
1190 #if OPENSSL_VERSION_NUMBER >= 0x10100000
1191   // Set the min TLS version of the SSL context if using OpenSSL version
1192   // >= 1.1.0. This OpenSSL version is required because the
1193   // |SSL_CTX_set_min_proto_version| and |SSL_CTX_set_max_proto_version| APIs
1194   // only exist in this version range.
1195   switch (min_tls_version) {
1196     case tsi_tls_version::TSI_TLS1_2:
1197       SSL_CTX_set_min_proto_version(ssl_context, TLS1_2_VERSION);
1198       break;
1199 #if defined(TLS1_3_VERSION)
1200     // If the library does not support TLS 1.3 and the caller requests a
1201     // minimum of TLS 1.3, then return an error because the caller's request
1202     // cannot be satisfied.
1203     case tsi_tls_version::TSI_TLS1_3:
1204       SSL_CTX_set_min_proto_version(ssl_context, TLS1_3_VERSION);
1205       break;
1206 #endif
1207     default:
1208       GRPC_TRACE_LOG(tsi, INFO) << "TLS version is not supported.";
1209       return TSI_FAILED_PRECONDITION;
1210   }
1211 
1212   // Set the max TLS version of the SSL context.
1213   switch (max_tls_version) {
1214     case tsi_tls_version::TSI_TLS1_2:
1215       SSL_CTX_set_max_proto_version(ssl_context, TLS1_2_VERSION);
1216       break;
1217     case tsi_tls_version::TSI_TLS1_3:
1218 #if defined(TLS1_3_VERSION)
1219       SSL_CTX_set_max_proto_version(ssl_context, TLS1_3_VERSION);
1220 #else
1221       // If the library does not support TLS 1.3, then set the max TLS version
1222       // to TLS 1.2 instead.
1223       SSL_CTX_set_max_proto_version(ssl_context, TLS1_2_VERSION);
1224 #endif
1225       break;
1226     default:
1227       GRPC_TRACE_LOG(tsi, INFO) << "TLS version is not supported.";
1228       return TSI_FAILED_PRECONDITION;
1229   }
1230 #endif
1231   return TSI_OK;
1232 }
1233 
1234 // --- tsi_ssl_root_certs_store methods implementation. ---
1235 
tsi_ssl_root_certs_store_create(const char * pem_roots)1236 tsi_ssl_root_certs_store* tsi_ssl_root_certs_store_create(
1237     const char* pem_roots) {
1238   if (pem_roots == nullptr) {
1239     LOG(ERROR) << "The root certificates are empty.";
1240     return nullptr;
1241   }
1242   tsi_ssl_root_certs_store* root_store = static_cast<tsi_ssl_root_certs_store*>(
1243       gpr_zalloc(sizeof(tsi_ssl_root_certs_store)));
1244   if (root_store == nullptr) {
1245     LOG(ERROR) << "Could not allocate buffer for ssl_root_certs_store.";
1246     return nullptr;
1247   }
1248   root_store->store = X509_STORE_new();
1249   if (root_store->store == nullptr) {
1250     LOG(ERROR) << "Could not allocate buffer for X509_STORE.";
1251     gpr_free(root_store);
1252     return nullptr;
1253   }
1254   tsi_result result = x509_store_load_certs(root_store->store, pem_roots,
1255                                             strlen(pem_roots), nullptr);
1256   if (result != TSI_OK) {
1257     LOG(ERROR) << "Could not load root certificates.";
1258     X509_STORE_free(root_store->store);
1259     gpr_free(root_store);
1260     return nullptr;
1261   }
1262 #if OPENSSL_VERSION_NUMBER >= 0x10100000
1263   X509_VERIFY_PARAM* param = X509_STORE_get0_param(root_store->store);
1264 #else
1265   X509_VERIFY_PARAM* param = root_store->store->param;
1266 #endif
1267   X509_VERIFY_PARAM_set_depth(param, kMaxChainLength);
1268   return root_store;
1269 }
1270 
tsi_ssl_root_certs_store_destroy(tsi_ssl_root_certs_store * self)1271 void tsi_ssl_root_certs_store_destroy(tsi_ssl_root_certs_store* self) {
1272   if (self == nullptr) return;
1273   X509_STORE_free(self->store);
1274   gpr_free(self);
1275 }
1276 
1277 // --- tsi_ssl_session_cache methods implementation. ---
1278 
tsi_ssl_session_cache_create_lru(size_t capacity)1279 tsi_ssl_session_cache* tsi_ssl_session_cache_create_lru(size_t capacity) {
1280   // Pointer will be dereferenced by unref call.
1281   return tsi::SslSessionLRUCache::Create(capacity).release()->c_ptr();
1282 }
1283 
tsi_ssl_session_cache_ref(tsi_ssl_session_cache * cache)1284 void tsi_ssl_session_cache_ref(tsi_ssl_session_cache* cache) {
1285   // Pointer will be dereferenced by unref call.
1286   tsi::SslSessionLRUCache::FromC(cache)->Ref().release();
1287 }
1288 
tsi_ssl_session_cache_unref(tsi_ssl_session_cache * cache)1289 void tsi_ssl_session_cache_unref(tsi_ssl_session_cache* cache) {
1290   tsi::SslSessionLRUCache::FromC(cache)->Unref();
1291 }
1292 
1293 // --- tsi_frame_protector methods implementation. ---
1294 
ssl_protector_protect(tsi_frame_protector * self,const unsigned char * unprotected_bytes,size_t * unprotected_bytes_size,unsigned char * protected_output_frames,size_t * protected_output_frames_size)1295 static tsi_result ssl_protector_protect(tsi_frame_protector* self,
1296                                         const unsigned char* unprotected_bytes,
1297                                         size_t* unprotected_bytes_size,
1298                                         unsigned char* protected_output_frames,
1299                                         size_t* protected_output_frames_size) {
1300   tsi_ssl_frame_protector* impl =
1301       reinterpret_cast<tsi_ssl_frame_protector*>(self);
1302 
1303   return grpc_core::SslProtectorProtect(
1304       unprotected_bytes, impl->buffer_size, impl->buffer_offset, impl->buffer,
1305       impl->ssl, impl->network_io, unprotected_bytes_size,
1306       protected_output_frames, protected_output_frames_size);
1307 }
1308 
ssl_protector_protect_flush(tsi_frame_protector * self,unsigned char * protected_output_frames,size_t * protected_output_frames_size,size_t * still_pending_size)1309 static tsi_result ssl_protector_protect_flush(
1310     tsi_frame_protector* self, unsigned char* protected_output_frames,
1311     size_t* protected_output_frames_size, size_t* still_pending_size) {
1312   tsi_ssl_frame_protector* impl =
1313       reinterpret_cast<tsi_ssl_frame_protector*>(self);
1314   return grpc_core::SslProtectorProtectFlush(
1315       impl->buffer_offset, impl->buffer, impl->ssl, impl->network_io,
1316       protected_output_frames, protected_output_frames_size,
1317       still_pending_size);
1318 }
1319 
ssl_protector_unprotect(tsi_frame_protector * self,const unsigned char * protected_frames_bytes,size_t * protected_frames_bytes_size,unsigned char * unprotected_bytes,size_t * unprotected_bytes_size)1320 static tsi_result ssl_protector_unprotect(
1321     tsi_frame_protector* self, const unsigned char* protected_frames_bytes,
1322     size_t* protected_frames_bytes_size, unsigned char* unprotected_bytes,
1323     size_t* unprotected_bytes_size) {
1324   tsi_ssl_frame_protector* impl =
1325       reinterpret_cast<tsi_ssl_frame_protector*>(self);
1326   return grpc_core::SslProtectorUnprotect(
1327       protected_frames_bytes, impl->ssl, impl->network_io,
1328       protected_frames_bytes_size, unprotected_bytes, unprotected_bytes_size);
1329 }
1330 
ssl_protector_destroy(tsi_frame_protector * self)1331 static void ssl_protector_destroy(tsi_frame_protector* self) {
1332   tsi_ssl_frame_protector* impl =
1333       reinterpret_cast<tsi_ssl_frame_protector*>(self);
1334   if (impl->buffer != nullptr) gpr_free(impl->buffer);
1335   if (impl->ssl != nullptr) SSL_free(impl->ssl);
1336   if (impl->network_io != nullptr) BIO_free(impl->network_io);
1337   gpr_free(self);
1338 }
1339 
1340 static const tsi_frame_protector_vtable frame_protector_vtable = {
1341     ssl_protector_protect,
1342     ssl_protector_protect_flush,
1343     ssl_protector_unprotect,
1344     ssl_protector_destroy,
1345 };
1346 
1347 // --- tsi_server_handshaker_factory methods implementation. ---
1348 
tsi_ssl_handshaker_factory_destroy(tsi_ssl_handshaker_factory * factory)1349 static void tsi_ssl_handshaker_factory_destroy(
1350     tsi_ssl_handshaker_factory* factory) {
1351   if (factory == nullptr) return;
1352 
1353   if (factory->vtable != nullptr && factory->vtable->destroy != nullptr) {
1354     factory->vtable->destroy(factory);
1355   }
1356   // Note, we don't free(self) here because this object is always directly
1357   // embedded in another object. If tsi_ssl_handshaker_factory_init allocates
1358   // any memory, it should be free'd here.
1359 }
1360 
tsi_ssl_handshaker_factory_ref(tsi_ssl_handshaker_factory * factory)1361 static tsi_ssl_handshaker_factory* tsi_ssl_handshaker_factory_ref(
1362     tsi_ssl_handshaker_factory* factory) {
1363   if (factory == nullptr) return nullptr;
1364   gpr_refn(&factory->refcount, 1);
1365   return factory;
1366 }
1367 
tsi_ssl_handshaker_factory_unref(tsi_ssl_handshaker_factory * factory)1368 static void tsi_ssl_handshaker_factory_unref(
1369     tsi_ssl_handshaker_factory* factory) {
1370   if (factory == nullptr) return;
1371 
1372   if (gpr_unref(&factory->refcount)) {
1373     tsi_ssl_handshaker_factory_destroy(factory);
1374   }
1375 }
1376 
1377 static tsi_ssl_handshaker_factory_vtable handshaker_factory_vtable = {nullptr};
1378 
1379 // Initializes a tsi_ssl_handshaker_factory object. Caller is responsible for
1380 // allocating memory for the factory.
tsi_ssl_handshaker_factory_init(tsi_ssl_handshaker_factory * factory)1381 static void tsi_ssl_handshaker_factory_init(
1382     tsi_ssl_handshaker_factory* factory) {
1383   CHECK_NE(factory, nullptr);
1384 
1385   factory->vtable = &handshaker_factory_vtable;
1386   gpr_ref_init(&factory->refcount, 1);
1387 }
1388 
1389 // Gets the X509 cert chain in PEM format as a tsi_peer_property.
tsi_ssl_get_cert_chain_contents(STACK_OF (X509)* peer_chain,tsi_peer_property * property)1390 tsi_result tsi_ssl_get_cert_chain_contents(STACK_OF(X509) * peer_chain,
1391                                            tsi_peer_property* property) {
1392   BIO* bio = BIO_new(BIO_s_mem());
1393   const auto peer_chain_len = sk_X509_num(peer_chain);
1394   for (auto i = decltype(peer_chain_len){0}; i < peer_chain_len; i++) {
1395     if (!PEM_write_bio_X509(bio, sk_X509_value(peer_chain, i))) {
1396       BIO_free(bio);
1397       return TSI_INTERNAL_ERROR;
1398     }
1399   }
1400   char* contents;
1401   long len = BIO_get_mem_data(bio, &contents);
1402   if (len <= 0) {
1403     BIO_free(bio);
1404     return TSI_INTERNAL_ERROR;
1405   }
1406   tsi_result result = tsi_construct_string_peer_property(
1407       TSI_X509_PEM_CERT_CHAIN_PROPERTY, contents, static_cast<size_t>(len),
1408       property);
1409   BIO_free(bio);
1410   return result;
1411 }
1412 
1413 // --- tsi_handshaker_result methods implementation. ---
ssl_handshaker_result_extract_peer(const tsi_handshaker_result * self,tsi_peer * peer)1414 static tsi_result ssl_handshaker_result_extract_peer(
1415     const tsi_handshaker_result* self, tsi_peer* peer) {
1416   tsi_result result = TSI_OK;
1417   const unsigned char* alpn_selected = nullptr;
1418   unsigned int alpn_selected_len;
1419   const tsi_ssl_handshaker_result* impl =
1420       reinterpret_cast<const tsi_ssl_handshaker_result*>(self);
1421   X509* peer_cert = SSL_get_peer_certificate(impl->ssl);
1422   if (peer_cert != nullptr) {
1423     result = peer_from_x509(peer_cert, 1, peer);
1424     X509_free(peer_cert);
1425     if (result != TSI_OK) return result;
1426   }
1427 #if TSI_OPENSSL_ALPN_SUPPORT
1428   SSL_get0_alpn_selected(impl->ssl, &alpn_selected, &alpn_selected_len);
1429 #endif  // TSI_OPENSSL_ALPN_SUPPORT
1430   if (alpn_selected == nullptr) {
1431     // Try npn.
1432     SSL_get0_next_proto_negotiated(impl->ssl, &alpn_selected,
1433                                    &alpn_selected_len);
1434   }
1435   // When called on the client side, the stack also contains the
1436   // peer's certificate; When called on the server side,
1437   // the peer's certificate is not present in the stack
1438   STACK_OF(X509)* peer_chain = SSL_get_peer_cert_chain(impl->ssl);
1439 
1440   X509* verified_root_cert = static_cast<X509*>(
1441       SSL_get_ex_data(impl->ssl, g_ssl_ex_verified_root_cert_index));
1442   // 1 is for session reused property.
1443   size_t new_property_count = peer->property_count + 3;
1444   if (alpn_selected != nullptr) new_property_count++;
1445   if (peer_chain != nullptr) new_property_count++;
1446   if (verified_root_cert != nullptr) new_property_count++;
1447   tsi_peer_property* new_properties = static_cast<tsi_peer_property*>(
1448       gpr_zalloc(sizeof(*new_properties) * new_property_count));
1449   for (size_t i = 0; i < peer->property_count; i++) {
1450     new_properties[i] = peer->properties[i];
1451   }
1452   if (peer->properties != nullptr) gpr_free(peer->properties);
1453   peer->properties = new_properties;
1454   // Add peer chain if available
1455   if (peer_chain != nullptr) {
1456     result = tsi_ssl_get_cert_chain_contents(
1457         peer_chain, &peer->properties[peer->property_count]);
1458     if (result == TSI_OK) peer->property_count++;
1459   }
1460   if (alpn_selected != nullptr) {
1461     result = tsi_construct_string_peer_property(
1462         TSI_SSL_ALPN_SELECTED_PROTOCOL,
1463         reinterpret_cast<const char*>(alpn_selected), alpn_selected_len,
1464         &peer->properties[peer->property_count]);
1465     if (result != TSI_OK) return result;
1466     peer->property_count++;
1467   }
1468   // Add security_level peer property.
1469   result = tsi_construct_string_peer_property_from_cstring(
1470       TSI_SECURITY_LEVEL_PEER_PROPERTY,
1471       tsi_security_level_to_string(TSI_PRIVACY_AND_INTEGRITY),
1472       &peer->properties[peer->property_count]);
1473   if (result != TSI_OK) return result;
1474   peer->property_count++;
1475 
1476   const char* session_reused = SSL_session_reused(impl->ssl) ? "true" : "false";
1477   result = tsi_construct_string_peer_property_from_cstring(
1478       TSI_SSL_SESSION_REUSED_PEER_PROPERTY, session_reused,
1479       &peer->properties[peer->property_count]);
1480   if (result != TSI_OK) return result;
1481   peer->property_count++;
1482 
1483   if (verified_root_cert != nullptr) {
1484     result = peer_property_from_x509_subject(
1485         verified_root_cert, &peer->properties[peer->property_count], true);
1486     if (result != TSI_OK) {
1487       VLOG(2) << "Problem extracting subject from verified_root_cert. result: "
1488               << result;
1489     }
1490     peer->property_count++;
1491   }
1492 
1493   return result;
1494 }
1495 
ssl_handshaker_result_get_frame_protector_type(const tsi_handshaker_result *,tsi_frame_protector_type * frame_protector_type)1496 static tsi_result ssl_handshaker_result_get_frame_protector_type(
1497     const tsi_handshaker_result* /*self*/,
1498     tsi_frame_protector_type* frame_protector_type) {
1499   *frame_protector_type = TSI_FRAME_PROTECTOR_NORMAL;
1500   return TSI_OK;
1501 }
1502 
ssl_handshaker_result_create_frame_protector(const tsi_handshaker_result * self,size_t * max_output_protected_frame_size,tsi_frame_protector ** protector)1503 static tsi_result ssl_handshaker_result_create_frame_protector(
1504     const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
1505     tsi_frame_protector** protector) {
1506   size_t actual_max_output_protected_frame_size =
1507       TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND;
1508   tsi_ssl_handshaker_result* impl =
1509       reinterpret_cast<tsi_ssl_handshaker_result*>(
1510           const_cast<tsi_handshaker_result*>(self));
1511   tsi_ssl_frame_protector* protector_impl =
1512       static_cast<tsi_ssl_frame_protector*>(
1513           gpr_zalloc(sizeof(*protector_impl)));
1514 
1515   if (max_output_protected_frame_size != nullptr) {
1516     if (*max_output_protected_frame_size >
1517         TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND) {
1518       *max_output_protected_frame_size =
1519           TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND;
1520     } else if (*max_output_protected_frame_size <
1521                TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND) {
1522       *max_output_protected_frame_size =
1523           TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND;
1524     }
1525     actual_max_output_protected_frame_size = *max_output_protected_frame_size;
1526   }
1527   protector_impl->buffer_size =
1528       actual_max_output_protected_frame_size - TSI_SSL_MAX_PROTECTION_OVERHEAD;
1529   protector_impl->buffer =
1530       static_cast<unsigned char*>(gpr_malloc(protector_impl->buffer_size));
1531   if (protector_impl->buffer == nullptr) {
1532     LOG(ERROR) << "Could not allocate buffer for tsi_ssl_frame_protector.";
1533     gpr_free(protector_impl);
1534     return TSI_INTERNAL_ERROR;
1535   }
1536 
1537   // Transfer ownership of ssl and network_io to the frame protector.
1538   protector_impl->ssl = impl->ssl;
1539   impl->ssl = nullptr;
1540   protector_impl->network_io = impl->network_io;
1541   impl->network_io = nullptr;
1542   protector_impl->base.vtable = &frame_protector_vtable;
1543   *protector = &protector_impl->base;
1544   return TSI_OK;
1545 }
1546 
ssl_handshaker_result_get_unused_bytes(const tsi_handshaker_result * self,const unsigned char ** bytes,size_t * bytes_size)1547 static tsi_result ssl_handshaker_result_get_unused_bytes(
1548     const tsi_handshaker_result* self, const unsigned char** bytes,
1549     size_t* bytes_size) {
1550   const tsi_ssl_handshaker_result* impl =
1551       reinterpret_cast<const tsi_ssl_handshaker_result*>(self);
1552   *bytes_size = impl->unused_bytes_size;
1553   *bytes = impl->unused_bytes;
1554   return TSI_OK;
1555 }
1556 
ssl_handshaker_result_destroy(tsi_handshaker_result * self)1557 static void ssl_handshaker_result_destroy(tsi_handshaker_result* self) {
1558   tsi_ssl_handshaker_result* impl =
1559       reinterpret_cast<tsi_ssl_handshaker_result*>(self);
1560   SSL_free(impl->ssl);
1561   BIO_free(impl->network_io);
1562   gpr_free(impl->unused_bytes);
1563   gpr_free(impl);
1564 }
1565 
1566 static const tsi_handshaker_result_vtable handshaker_result_vtable = {
1567     ssl_handshaker_result_extract_peer,
1568     ssl_handshaker_result_get_frame_protector_type,
1569     nullptr,  // create_zero_copy_grpc_protector
1570     ssl_handshaker_result_create_frame_protector,
1571     ssl_handshaker_result_get_unused_bytes,
1572     ssl_handshaker_result_destroy,
1573 };
1574 
ssl_handshaker_result_create(tsi_ssl_handshaker * handshaker,unsigned char * unused_bytes,size_t unused_bytes_size,tsi_handshaker_result ** handshaker_result,std::string * error)1575 static tsi_result ssl_handshaker_result_create(
1576     tsi_ssl_handshaker* handshaker, unsigned char* unused_bytes,
1577     size_t unused_bytes_size, tsi_handshaker_result** handshaker_result,
1578     std::string* error) {
1579   if (handshaker == nullptr || handshaker_result == nullptr ||
1580       (unused_bytes_size > 0 && unused_bytes == nullptr)) {
1581     if (error != nullptr) *error = "invalid argument";
1582     return TSI_INVALID_ARGUMENT;
1583   }
1584   tsi_ssl_handshaker_result* result =
1585       grpc_core::Zalloc<tsi_ssl_handshaker_result>();
1586   result->base.vtable = &handshaker_result_vtable;
1587   // Transfer ownership of ssl and network_io to the handshaker result.
1588   result->ssl = handshaker->ssl;
1589   handshaker->ssl = nullptr;
1590   result->network_io = handshaker->network_io;
1591   handshaker->network_io = nullptr;
1592   // Transfer ownership of |unused_bytes| to the handshaker result.
1593   result->unused_bytes = unused_bytes;
1594   result->unused_bytes_size = unused_bytes_size;
1595   *handshaker_result = &result->base;
1596   return TSI_OK;
1597 }
1598 
1599 // --- tsi_handshaker methods implementation. ---
1600 
ssl_handshaker_get_bytes_to_send_to_peer(tsi_ssl_handshaker * impl,unsigned char * bytes,size_t * bytes_size,std::string * error)1601 static tsi_result ssl_handshaker_get_bytes_to_send_to_peer(
1602     tsi_ssl_handshaker* impl, unsigned char* bytes, size_t* bytes_size,
1603     std::string* error) {
1604   int bytes_read_from_ssl = 0;
1605   if (bytes == nullptr || bytes_size == nullptr || *bytes_size > INT_MAX) {
1606     if (error != nullptr) *error = "invalid argument";
1607     return TSI_INVALID_ARGUMENT;
1608   }
1609   CHECK_LE(*bytes_size, static_cast<size_t>(INT_MAX));
1610   bytes_read_from_ssl =
1611       BIO_read(impl->network_io, bytes, static_cast<int>(*bytes_size));
1612   if (bytes_read_from_ssl < 0) {
1613     *bytes_size = 0;
1614     if (!BIO_should_retry(impl->network_io)) {
1615       if (error != nullptr) *error = "error reading from BIO";
1616       impl->result = TSI_INTERNAL_ERROR;
1617       return impl->result;
1618     } else {
1619       return TSI_OK;
1620     }
1621   }
1622   *bytes_size = static_cast<size_t>(bytes_read_from_ssl);
1623   return BIO_pending(impl->network_io) == 0 ? TSI_OK : TSI_INCOMPLETE_DATA;
1624 }
1625 
ssl_handshaker_get_result(tsi_ssl_handshaker * impl)1626 static tsi_result ssl_handshaker_get_result(tsi_ssl_handshaker* impl) {
1627   if ((impl->result == TSI_HANDSHAKE_IN_PROGRESS) &&
1628       SSL_is_init_finished(impl->ssl)) {
1629     impl->result = TSI_OK;
1630   }
1631   return impl->result;
1632 }
1633 
ssl_handshaker_do_handshake(tsi_ssl_handshaker * impl,std::string * error)1634 static tsi_result ssl_handshaker_do_handshake(tsi_ssl_handshaker* impl,
1635                                               std::string* error) {
1636   if (ssl_handshaker_get_result(impl) != TSI_HANDSHAKE_IN_PROGRESS) {
1637     impl->result = TSI_OK;
1638     return impl->result;
1639   } else {
1640     ERR_clear_error();
1641     // Get ready to get some bytes from SSL.
1642     int ssl_result = SSL_do_handshake(impl->ssl);
1643     ssl_result = SSL_get_error(impl->ssl, ssl_result);
1644     switch (ssl_result) {
1645       case SSL_ERROR_WANT_READ:
1646         if (BIO_pending(impl->network_io) == 0) {
1647           // We need more data.
1648           return TSI_INCOMPLETE_DATA;
1649         } else {
1650           return TSI_OK;
1651         }
1652       case SSL_ERROR_NONE:
1653         return TSI_OK;
1654       case SSL_ERROR_WANT_WRITE:
1655         return TSI_DRAIN_BUFFER;
1656       default: {
1657         char err_str[256];
1658         ERR_error_string_n(ERR_get_error(), err_str, sizeof(err_str));
1659         long verify_result = SSL_get_verify_result(impl->ssl);
1660         std::string verify_result_str;
1661         if (verify_result != X509_V_OK) {
1662           const char* verify_err = X509_verify_cert_error_string(verify_result);
1663           verify_result_str = absl::StrCat(": ", verify_err);
1664         }
1665         LOG(INFO) << "Handshake failed with error "
1666                   << grpc_core::SslErrorString(ssl_result) << ": " << err_str
1667                   << verify_result_str;
1668         if (error != nullptr) {
1669           *error = absl::StrCat(grpc_core::SslErrorString(ssl_result), ": ",
1670                                 err_str, verify_result_str);
1671         }
1672         impl->result = TSI_PROTOCOL_FAILURE;
1673         return impl->result;
1674       }
1675     }
1676   }
1677 }
1678 
ssl_handshaker_process_bytes_from_peer(tsi_ssl_handshaker * impl,const unsigned char * bytes,size_t * bytes_size,std::string * error)1679 static tsi_result ssl_handshaker_process_bytes_from_peer(
1680     tsi_ssl_handshaker* impl, const unsigned char* bytes, size_t* bytes_size,
1681     std::string* error) {
1682   int bytes_written_into_ssl_size = 0;
1683   if (bytes == nullptr || bytes_size == nullptr || *bytes_size > INT_MAX) {
1684     if (error != nullptr) *error = "invalid argument";
1685     return TSI_INVALID_ARGUMENT;
1686   }
1687   CHECK_LE(*bytes_size, static_cast<size_t>(INT_MAX));
1688   bytes_written_into_ssl_size =
1689       BIO_write(impl->network_io, bytes, static_cast<int>(*bytes_size));
1690   if (bytes_written_into_ssl_size < 0) {
1691     LOG(ERROR) << "Could not write to memory BIO.";
1692     if (error != nullptr) *error = "could not write to memory BIO";
1693     impl->result = TSI_INTERNAL_ERROR;
1694     return impl->result;
1695   }
1696   *bytes_size = static_cast<size_t>(bytes_written_into_ssl_size);
1697   return ssl_handshaker_do_handshake(impl, error);
1698 }
1699 
ssl_handshaker_destroy(tsi_handshaker * self)1700 static void ssl_handshaker_destroy(tsi_handshaker* self) {
1701   tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
1702   SSL_free(impl->ssl);
1703   BIO_free(impl->network_io);
1704   gpr_free(impl->outgoing_bytes_buffer);
1705   tsi_ssl_handshaker_factory_unref(impl->factory_ref);
1706   gpr_free(impl);
1707 }
1708 
1709 // Removes the bytes remaining in |impl->SSL|'s read BIO and writes them to
1710 // |bytes_remaining|.
ssl_bytes_remaining(tsi_ssl_handshaker * impl,unsigned char ** bytes_remaining,size_t * bytes_remaining_size,std::string * error)1711 static tsi_result ssl_bytes_remaining(tsi_ssl_handshaker* impl,
1712                                       unsigned char** bytes_remaining,
1713                                       size_t* bytes_remaining_size,
1714                                       std::string* error) {
1715   if (impl == nullptr || bytes_remaining == nullptr ||
1716       bytes_remaining_size == nullptr) {
1717     if (error != nullptr) *error = "invalid argument";
1718     return TSI_INVALID_ARGUMENT;
1719   }
1720   // Attempt to read all of the bytes in SSL's read BIO. These bytes should
1721   // contain application data records that were appended to a handshake record
1722   // containing the ClientFinished or ServerFinished message.
1723   size_t bytes_in_ssl = BIO_pending(SSL_get_rbio(impl->ssl));
1724   if (bytes_in_ssl == 0) return TSI_OK;
1725   *bytes_remaining = static_cast<uint8_t*>(gpr_malloc(bytes_in_ssl));
1726   int bytes_read = BIO_read(SSL_get_rbio(impl->ssl), *bytes_remaining,
1727                             static_cast<int>(bytes_in_ssl));
1728   // If an unexpected number of bytes were read, return an error status and
1729   // free all of the bytes that were read.
1730   if (bytes_read < 0 || static_cast<size_t>(bytes_read) != bytes_in_ssl) {
1731     LOG(ERROR)
1732         << "Failed to read the expected number of bytes from SSL object.";
1733     gpr_free(*bytes_remaining);
1734     *bytes_remaining = nullptr;
1735     if (error != nullptr) {
1736       *error = "Failed to read the expected number of bytes from SSL object.";
1737     }
1738     return TSI_INTERNAL_ERROR;
1739   }
1740   *bytes_remaining_size = static_cast<size_t>(bytes_read);
1741   return TSI_OK;
1742 }
1743 
1744 // Write handshake data received from SSL to an unbound output buffer.
1745 // By doing that, we drain SSL bio buffer used to hold handshake data.
1746 // This API needs to be repeatedly called until all handshake data are
1747 // received from SSL.
ssl_handshaker_write_output_buffer(tsi_handshaker * self,size_t * bytes_written,std::string * error)1748 static tsi_result ssl_handshaker_write_output_buffer(tsi_handshaker* self,
1749                                                      size_t* bytes_written,
1750                                                      std::string* error) {
1751   tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
1752   tsi_result status = TSI_OK;
1753   size_t offset = *bytes_written;
1754   do {
1755     size_t to_send_size = impl->outgoing_bytes_buffer_size - offset;
1756     status = ssl_handshaker_get_bytes_to_send_to_peer(
1757         impl, impl->outgoing_bytes_buffer + offset, &to_send_size, error);
1758     offset += to_send_size;
1759     if (status == TSI_INCOMPLETE_DATA) {
1760       impl->outgoing_bytes_buffer_size *= 2;
1761       impl->outgoing_bytes_buffer = static_cast<unsigned char*>(gpr_realloc(
1762           impl->outgoing_bytes_buffer, impl->outgoing_bytes_buffer_size));
1763     }
1764   } while (status == TSI_INCOMPLETE_DATA);
1765   *bytes_written = offset;
1766   return status;
1767 }
1768 
ssl_handshaker_next(tsi_handshaker * self,const unsigned char * received_bytes,size_t received_bytes_size,const unsigned char ** bytes_to_send,size_t * bytes_to_send_size,tsi_handshaker_result ** handshaker_result,tsi_handshaker_on_next_done_cb,void *,std::string * error)1769 static tsi_result ssl_handshaker_next(tsi_handshaker* self,
1770                                       const unsigned char* received_bytes,
1771                                       size_t received_bytes_size,
1772                                       const unsigned char** bytes_to_send,
1773                                       size_t* bytes_to_send_size,
1774                                       tsi_handshaker_result** handshaker_result,
1775                                       tsi_handshaker_on_next_done_cb /*cb*/,
1776                                       void* /*user_data*/, std::string* error) {
1777   // Input sanity check.
1778   if ((received_bytes_size > 0 && received_bytes == nullptr) ||
1779       bytes_to_send == nullptr || bytes_to_send_size == nullptr ||
1780       handshaker_result == nullptr) {
1781     if (error != nullptr) *error = "invalid argument";
1782     return TSI_INVALID_ARGUMENT;
1783   }
1784   // If there are received bytes, process them first.
1785   tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
1786   tsi_result status = TSI_OK;
1787   size_t bytes_written = 0;
1788   if (received_bytes_size > 0) {
1789     unsigned char* remaining_bytes_to_write_to_openssl =
1790         const_cast<unsigned char*>(received_bytes);
1791     size_t remaining_bytes_to_write_to_openssl_size = received_bytes_size;
1792     size_t number_bio_write_attempts = 0;
1793     while (remaining_bytes_to_write_to_openssl_size > 0 &&
1794            (status == TSI_OK || status == TSI_INCOMPLETE_DATA) &&
1795            number_bio_write_attempts < TSI_SSL_MAX_BIO_WRITE_ATTEMPTS) {
1796       ++number_bio_write_attempts;
1797       // Try to write all of the remaining bytes to the BIO.
1798       size_t bytes_written_to_openssl =
1799           remaining_bytes_to_write_to_openssl_size;
1800       status = ssl_handshaker_process_bytes_from_peer(
1801           impl, remaining_bytes_to_write_to_openssl, &bytes_written_to_openssl,
1802           error);
1803       // As long as the BIO is full, drive the SSL handshake to consume bytes
1804       // from the BIO. If the SSL handshake returns any bytes, write them to
1805       // the peer.
1806       while (status == TSI_DRAIN_BUFFER) {
1807         status =
1808             ssl_handshaker_write_output_buffer(self, &bytes_written, error);
1809         if (status != TSI_OK) return status;
1810         status = ssl_handshaker_do_handshake(impl, error);
1811       }
1812       // Move the pointer to the first byte not yet successfully written to
1813       // the BIO.
1814       remaining_bytes_to_write_to_openssl_size -= bytes_written_to_openssl;
1815       remaining_bytes_to_write_to_openssl += bytes_written_to_openssl;
1816     }
1817   }
1818   if (status != TSI_OK) return status;
1819   // Get bytes to send to the peer, if available.
1820   status = ssl_handshaker_write_output_buffer(self, &bytes_written, error);
1821   if (status != TSI_OK) return status;
1822   *bytes_to_send = impl->outgoing_bytes_buffer;
1823   *bytes_to_send_size = bytes_written;
1824   // If handshake completes, create tsi_handshaker_result.
1825   if (ssl_handshaker_get_result(impl) == TSI_HANDSHAKE_IN_PROGRESS) {
1826     *handshaker_result = nullptr;
1827   } else {
1828     // Any bytes that remain in |impl->ssl|'s read BIO after the handshake is
1829     // complete must be extracted and set to the unused bytes of the
1830     // handshaker result. This indicates to the gRPC stack that there are
1831     // bytes from the peer that must be processed.
1832     unsigned char* unused_bytes = nullptr;
1833     size_t unused_bytes_size = 0;
1834     status =
1835         ssl_bytes_remaining(impl, &unused_bytes, &unused_bytes_size, error);
1836     if (status != TSI_OK) return status;
1837     if (unused_bytes_size > received_bytes_size) {
1838       LOG(ERROR) << "More unused bytes than received bytes.";
1839       gpr_free(unused_bytes);
1840       if (error != nullptr) *error = "More unused bytes than received bytes.";
1841       return TSI_INTERNAL_ERROR;
1842     }
1843     status = ssl_handshaker_result_create(impl, unused_bytes, unused_bytes_size,
1844                                           handshaker_result, error);
1845     if (status == TSI_OK) {
1846       // Indicates that the handshake has completed and that a
1847       // handshaker_result has been created.
1848       self->handshaker_result_created = true;
1849       // Output Cipher information
1850       if (GRPC_TRACE_FLAG_ENABLED(tsi)) {
1851         tsi_ssl_handshaker_result* result =
1852             reinterpret_cast<tsi_ssl_handshaker_result*>(*handshaker_result);
1853         auto cipher = SSL_get_current_cipher(result->ssl);
1854         if (cipher != nullptr) {
1855           GRPC_TRACE_LOG(tsi, INFO) << absl::StrFormat(
1856               "SSL Cipher Version: %s Name: %s", SSL_CIPHER_get_version(cipher),
1857               SSL_CIPHER_get_name(cipher));
1858         }
1859       }
1860     }
1861   }
1862   return status;
1863 }
1864 
1865 static const tsi_handshaker_vtable handshaker_vtable = {
1866     nullptr,  // get_bytes_to_send_to_peer -- deprecated
1867     nullptr,  // process_bytes_from_peer   -- deprecated
1868     nullptr,  // get_result                -- deprecated
1869     nullptr,  // extract_peer              -- deprecated
1870     nullptr,  // create_frame_protector    -- deprecated
1871     ssl_handshaker_destroy,
1872     ssl_handshaker_next,
1873     nullptr,  // shutdown
1874 };
1875 
1876 // --- tsi_ssl_handshaker_factory common methods. ---
1877 
tsi_ssl_handshaker_resume_session(SSL * ssl,tsi::SslSessionLRUCache * session_cache)1878 static void tsi_ssl_handshaker_resume_session(
1879     SSL* ssl, tsi::SslSessionLRUCache* session_cache) {
1880   const char* server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1881   if (server_name == nullptr) {
1882     return;
1883   }
1884   tsi::SslSessionPtr session = session_cache->Get(server_name);
1885   if (session != nullptr) {
1886     // SSL_set_session internally increments reference counter.
1887     SSL_set_session(ssl, session.get());
1888   }
1889 }
1890 
create_tsi_ssl_handshaker(SSL_CTX * ctx,int is_client,const char * server_name_indication,size_t network_bio_buf_size,size_t ssl_bio_buf_size,tsi_ssl_handshaker_factory * factory,tsi_handshaker ** handshaker)1891 static tsi_result create_tsi_ssl_handshaker(SSL_CTX* ctx, int is_client,
1892                                             const char* server_name_indication,
1893                                             size_t network_bio_buf_size,
1894                                             size_t ssl_bio_buf_size,
1895                                             tsi_ssl_handshaker_factory* factory,
1896                                             tsi_handshaker** handshaker) {
1897   SSL* ssl = SSL_new(ctx);
1898   BIO* network_io = nullptr;
1899   BIO* ssl_io = nullptr;
1900   tsi_ssl_handshaker* impl = nullptr;
1901   *handshaker = nullptr;
1902   if (ctx == nullptr) {
1903     LOG(ERROR) << "SSL Context is null. Should never happen.";
1904     return TSI_INTERNAL_ERROR;
1905   }
1906   if (ssl == nullptr) {
1907     return TSI_OUT_OF_RESOURCES;
1908   }
1909   SSL_set_info_callback(ssl, ssl_info_callback);
1910 
1911   if (!BIO_new_bio_pair(&network_io, network_bio_buf_size, &ssl_io,
1912                         ssl_bio_buf_size)) {
1913     LOG(ERROR) << "BIO_new_bio_pair failed.";
1914     SSL_free(ssl);
1915     return TSI_OUT_OF_RESOURCES;
1916   }
1917   SSL_set_bio(ssl, ssl_io, ssl_io);
1918   if (is_client) {
1919     int ssl_result;
1920     SSL_set_connect_state(ssl);
1921     // Skip if the SNI looks like an IP address because IP addressed are not
1922     // allowed as host names.
1923     if (server_name_indication != nullptr &&
1924         !looks_like_ip_address(server_name_indication)) {
1925       if (!SSL_set_tlsext_host_name(ssl, server_name_indication)) {
1926         LOG(ERROR) << "Invalid server name indication "
1927                    << server_name_indication;
1928         SSL_free(ssl);
1929         BIO_free(network_io);
1930         return TSI_INTERNAL_ERROR;
1931       }
1932     }
1933     tsi_ssl_client_handshaker_factory* client_factory =
1934         reinterpret_cast<tsi_ssl_client_handshaker_factory*>(factory);
1935     if (client_factory->session_cache != nullptr) {
1936       tsi_ssl_handshaker_resume_session(ssl,
1937                                         client_factory->session_cache.get());
1938     }
1939     ERR_clear_error();
1940     ssl_result = SSL_do_handshake(ssl);
1941     ssl_result = SSL_get_error(ssl, ssl_result);
1942     if (ssl_result != SSL_ERROR_WANT_READ) {
1943       LOG(ERROR)
1944           << "Unexpected error received from first SSL_do_handshake call: "
1945           << grpc_core::SslErrorString(ssl_result);
1946       SSL_free(ssl);
1947       BIO_free(network_io);
1948       return TSI_INTERNAL_ERROR;
1949     }
1950   } else {
1951     SSL_set_accept_state(ssl);
1952   }
1953 
1954   impl = grpc_core::Zalloc<tsi_ssl_handshaker>();
1955   impl->ssl = ssl;
1956   impl->network_io = network_io;
1957   impl->result = TSI_HANDSHAKE_IN_PROGRESS;
1958   impl->outgoing_bytes_buffer_size =
1959       TSI_SSL_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE;
1960   impl->outgoing_bytes_buffer =
1961       static_cast<unsigned char*>(gpr_zalloc(impl->outgoing_bytes_buffer_size));
1962   impl->base.vtable = &handshaker_vtable;
1963   impl->factory_ref = tsi_ssl_handshaker_factory_ref(factory);
1964   *handshaker = &impl->base;
1965   return TSI_OK;
1966 }
1967 
// Picks the first protocol of |client_list| that also appears in
// |server_list|.
//
// Both lists are sequences of length-prefixed entries: one length byte
// followed by that many bytes of protocol name. Entries are compared
// byte-for-byte. On a match, |*out| points at the matching name inside
// |server_list| (no copy is made), |*outlen| is its length and
// SSL_TLSEXT_ERR_OK is returned. Otherwise SSL_TLSEXT_ERR_NOACK is returned
// and the outputs are left untouched.
//
// An entry whose length prefix claims more bytes than its list has left is
// treated as the end of that list, so no byte beyond |client_list_len| /
// |server_list_len| is ever read.
static int select_protocol_list(const unsigned char** out,
                                unsigned char* outlen,
                                const unsigned char* client_list,
                                size_t client_list_len,
                                const unsigned char* server_list,
                                size_t server_list_len) {
  size_t client_pos = 0;
  while (client_pos < client_list_len) {
    const size_t client_entry_len = client_list[client_pos++];
    // Truncated client entry: nothing valid can follow it.
    if (client_entry_len > client_list_len - client_pos) break;
    const unsigned char* client_entry = client_list + client_pos;

    size_t server_pos = 0;
    while (server_pos < server_list_len) {
      const size_t server_entry_len = server_list[server_pos++];
      // Truncated server entry: stop scanning the server list.
      if (server_entry_len > server_list_len - server_pos) break;
      const unsigned char* server_entry = server_list + server_pos;
      if (client_entry_len == server_entry_len &&
          memcmp(client_entry, server_entry, server_entry_len) == 0) {
        *out = server_entry;
        *outlen = static_cast<unsigned char>(server_entry_len);
        return SSL_TLSEXT_ERR_OK;
      }
      server_pos += server_entry_len;
    }
    client_pos += client_entry_len;
  }
  return SSL_TLSEXT_ERR_NOACK;
}
1995 
1996 // --- tsi_ssl_client_handshaker_factory methods implementation. ---
1997 
tsi_ssl_client_handshaker_factory_create_handshaker(tsi_ssl_client_handshaker_factory * factory,const char * server_name_indication,size_t network_bio_buf_size,size_t ssl_bio_buf_size,tsi_handshaker ** handshaker)1998 tsi_result tsi_ssl_client_handshaker_factory_create_handshaker(
1999     tsi_ssl_client_handshaker_factory* factory,
2000     const char* server_name_indication, size_t network_bio_buf_size,
2001     size_t ssl_bio_buf_size, tsi_handshaker** handshaker) {
2002   return create_tsi_ssl_handshaker(
2003       factory->ssl_context, 1, server_name_indication, network_bio_buf_size,
2004       ssl_bio_buf_size, &factory->base, handshaker);
2005 }
2006 
tsi_ssl_client_handshaker_factory_unref(tsi_ssl_client_handshaker_factory * factory)2007 void tsi_ssl_client_handshaker_factory_unref(
2008     tsi_ssl_client_handshaker_factory* factory) {
2009   if (factory == nullptr) return;
2010   tsi_ssl_handshaker_factory_unref(&factory->base);
2011 }
2012 
tsi_ssl_client_handshaker_factory_ref(tsi_ssl_client_handshaker_factory * client_factory)2013 tsi_ssl_client_handshaker_factory* tsi_ssl_client_handshaker_factory_ref(
2014     tsi_ssl_client_handshaker_factory* client_factory) {
2015   if (client_factory == nullptr) return nullptr;
2016   return reinterpret_cast<tsi_ssl_client_handshaker_factory*>(
2017       tsi_ssl_handshaker_factory_ref(&client_factory->base));
2018 }
2019 
tsi_ssl_client_handshaker_factory_destroy(tsi_ssl_handshaker_factory * factory)2020 static void tsi_ssl_client_handshaker_factory_destroy(
2021     tsi_ssl_handshaker_factory* factory) {
2022   if (factory == nullptr) return;
2023   tsi_ssl_client_handshaker_factory* self =
2024       reinterpret_cast<tsi_ssl_client_handshaker_factory*>(factory);
2025   if (self->ssl_context != nullptr) SSL_CTX_free(self->ssl_context);
2026   if (self->alpn_protocol_list != nullptr) gpr_free(self->alpn_protocol_list);
2027   self->session_cache.reset();
2028   self->key_logger.reset();
2029   gpr_free(self);
2030 }
2031 
client_handshaker_factory_npn_callback(SSL *,unsigned char ** out,unsigned char * outlen,const unsigned char * in,unsigned int inlen,void * arg)2032 static int client_handshaker_factory_npn_callback(
2033     SSL* /*ssl*/, unsigned char** out, unsigned char* outlen,
2034     const unsigned char* in, unsigned int inlen, void* arg) {
2035   tsi_ssl_client_handshaker_factory* factory =
2036       static_cast<tsi_ssl_client_handshaker_factory*>(arg);
2037   return select_protocol_list(const_cast<const unsigned char**>(out), outlen,
2038                               factory->alpn_protocol_list,
2039                               factory->alpn_protocol_list_length, in, inlen);
2040 }
2041 
2042 // --- tsi_ssl_server_handshaker_factory methods implementation. ---
2043 
tsi_ssl_server_handshaker_factory_create_handshaker(tsi_ssl_server_handshaker_factory * factory,size_t network_bio_buf_size,size_t ssl_bio_buf_size,tsi_handshaker ** handshaker)2044 tsi_result tsi_ssl_server_handshaker_factory_create_handshaker(
2045     tsi_ssl_server_handshaker_factory* factory, size_t network_bio_buf_size,
2046     size_t ssl_bio_buf_size, tsi_handshaker** handshaker) {
2047   if (factory->ssl_context_count == 0) return TSI_INVALID_ARGUMENT;
2048   // Create the handshaker with the first context. We will switch if needed
2049   // because of SNI in ssl_server_handshaker_factory_servername_callback.
2050   return create_tsi_ssl_handshaker(factory->ssl_contexts[0], 0, nullptr,
2051                                    network_bio_buf_size, ssl_bio_buf_size,
2052                                    &factory->base, handshaker);
2053 }
2054 
tsi_ssl_server_handshaker_factory_unref(tsi_ssl_server_handshaker_factory * factory)2055 void tsi_ssl_server_handshaker_factory_unref(
2056     tsi_ssl_server_handshaker_factory* factory) {
2057   if (factory == nullptr) return;
2058   tsi_ssl_handshaker_factory_unref(&factory->base);
2059 }
2060 
tsi_ssl_server_handshaker_factory_destroy(tsi_ssl_handshaker_factory * factory)2061 static void tsi_ssl_server_handshaker_factory_destroy(
2062     tsi_ssl_handshaker_factory* factory) {
2063   if (factory == nullptr) return;
2064   tsi_ssl_server_handshaker_factory* self =
2065       reinterpret_cast<tsi_ssl_server_handshaker_factory*>(factory);
2066   size_t i;
2067   for (i = 0; i < self->ssl_context_count; i++) {
2068     if (self->ssl_contexts[i] != nullptr) {
2069       SSL_CTX_free(self->ssl_contexts[i]);
2070       tsi_peer_destruct(&self->ssl_context_x509_subject_names[i]);
2071     }
2072   }
2073   if (self->ssl_contexts != nullptr) gpr_free(self->ssl_contexts);
2074   if (self->ssl_context_x509_subject_names != nullptr) {
2075     gpr_free(self->ssl_context_x509_subject_names);
2076   }
2077   if (self->alpn_protocol_list != nullptr) gpr_free(self->alpn_protocol_list);
2078   self->key_logger.reset();
2079   gpr_free(self);
2080 }
2081 
does_entry_match_name(absl::string_view entry,absl::string_view name)2082 static int does_entry_match_name(absl::string_view entry,
2083                                  absl::string_view name) {
2084   if (entry.empty()) return 0;
2085 
2086   // Take care of '.' terminations.
2087   if (name.back() == '.') {
2088     name.remove_suffix(1);
2089   }
2090   if (entry.back() == '.') {
2091     entry.remove_suffix(1);
2092     if (entry.empty()) return 0;
2093   }
2094 
2095   if (absl::EqualsIgnoreCase(name, entry)) {
2096     return 1;  // Perfect match.
2097   }
2098   if (entry.front() != '*') return 0;
2099 
2100   // Wildchar subdomain matching.
2101   if (entry.size() < 3 || entry[1] != '.') {  // At least *.x
2102     LOG(ERROR) << "Invalid wildchar entry.";
2103     return 0;
2104   }
2105   size_t name_subdomain_pos = name.find('.');
2106   if (name_subdomain_pos == absl::string_view::npos) return 0;
2107   if (name_subdomain_pos >= name.size() - 2) return 0;
2108   absl::string_view name_subdomain =
2109       name.substr(name_subdomain_pos + 1);  // Starts after the dot.
2110   entry.remove_prefix(2);                   // Remove *.
2111   size_t dot = name_subdomain.find('.');
2112   if (dot == absl::string_view::npos || dot == name_subdomain.size() - 1) {
2113     LOG(ERROR) << "Invalid toplevel subdomain: " << name_subdomain;
2114     return 0;
2115   }
2116   if (name_subdomain.back() == '.') {
2117     name_subdomain.remove_suffix(1);
2118   }
2119   return !entry.empty() && absl::EqualsIgnoreCase(name_subdomain, entry);
2120 }
2121 
ssl_server_handshaker_factory_servername_callback(SSL * ssl,int *,void * arg)2122 static int ssl_server_handshaker_factory_servername_callback(SSL* ssl,
2123                                                              int* /*ap*/,
2124                                                              void* arg) {
2125   tsi_ssl_server_handshaker_factory* impl =
2126       static_cast<tsi_ssl_server_handshaker_factory*>(arg);
2127   size_t i = 0;
2128   const char* servername = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
2129   if (servername == nullptr || strlen(servername) == 0) {
2130     return SSL_TLSEXT_ERR_NOACK;
2131   }
2132 
2133   for (i = 0; i < impl->ssl_context_count; i++) {
2134     if (tsi_ssl_peer_matches_name(&impl->ssl_context_x509_subject_names[i],
2135                                   servername)) {
2136       SSL_set_SSL_CTX(ssl, impl->ssl_contexts[i]);
2137       return SSL_TLSEXT_ERR_OK;
2138     }
2139   }
2140   LOG(ERROR) << "No match found for server name: " << servername;
2141   return SSL_TLSEXT_ERR_NOACK;
2142 }
2143 
2144 #if TSI_OPENSSL_ALPN_SUPPORT
server_handshaker_factory_alpn_callback(SSL *,const unsigned char ** out,unsigned char * outlen,const unsigned char * in,unsigned int inlen,void * arg)2145 static int server_handshaker_factory_alpn_callback(
2146     SSL* /*ssl*/, const unsigned char** out, unsigned char* outlen,
2147     const unsigned char* in, unsigned int inlen, void* arg) {
2148   tsi_ssl_server_handshaker_factory* factory =
2149       static_cast<tsi_ssl_server_handshaker_factory*>(arg);
2150   return select_protocol_list(out, outlen, in, inlen,
2151                               factory->alpn_protocol_list,
2152                               factory->alpn_protocol_list_length);
2153 }
2154 #endif  // TSI_OPENSSL_ALPN_SUPPORT
2155 
server_handshaker_factory_npn_advertised_callback(SSL *,const unsigned char ** out,unsigned int * outlen,void * arg)2156 static int server_handshaker_factory_npn_advertised_callback(
2157     SSL* /*ssl*/, const unsigned char** out, unsigned int* outlen, void* arg) {
2158   tsi_ssl_server_handshaker_factory* factory =
2159       static_cast<tsi_ssl_server_handshaker_factory*>(arg);
2160   *out = factory->alpn_protocol_list;
2161   CHECK(factory->alpn_protocol_list_length <= UINT_MAX);
2162   *outlen = static_cast<unsigned int>(factory->alpn_protocol_list_length);
2163   return SSL_TLSEXT_ERR_OK;
2164 }
2165 
2166 /// This callback is called when new \a session is established and ready to
2167 /// be cached. This session can be reused for new connections to similar
2168 /// servers at later point of time.
2169 /// It's intended to be used with SSL_CTX_sess_set_new_cb function.
2170 ///
2171 /// It returns 1 if callback takes ownership over \a session and 0 otherwise.
server_handshaker_factory_new_session_callback(SSL * ssl,SSL_SESSION * session)2172 static int server_handshaker_factory_new_session_callback(
2173     SSL* ssl, SSL_SESSION* session) {
2174   SSL_CTX* ssl_context = SSL_get_SSL_CTX(ssl);
2175   if (ssl_context == nullptr) {
2176     return 0;
2177   }
2178   void* arg = SSL_CTX_get_ex_data(ssl_context, g_ssl_ctx_ex_factory_index);
2179   tsi_ssl_client_handshaker_factory* factory =
2180       static_cast<tsi_ssl_client_handshaker_factory*>(arg);
2181   const char* server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
2182   if (server_name == nullptr) {
2183     return 0;
2184   }
2185   factory->session_cache->Put(server_name, tsi::SslSessionPtr(session));
2186   // Return 1 to indicate transferred ownership over the given session.
2187   return 1;
2188 }
2189 
2190 /// This callback is invoked at client or server when ssl/tls handshakes
2191 /// complete and keylogging is enabled.
2192 template <typename T>
ssl_keylogging_callback(const SSL * ssl,const char * info)2193 static void ssl_keylogging_callback(const SSL* ssl, const char* info) {
2194   SSL_CTX* ssl_context = SSL_get_SSL_CTX(ssl);
2195   CHECK_NE(ssl_context, nullptr);
2196   void* arg = SSL_CTX_get_ex_data(ssl_context, g_ssl_ctx_ex_factory_index);
2197   T* factory = static_cast<T*>(arg);
2198   factory->key_logger->LogSessionKeys(ssl_context, info);
2199 }
2200 
2201 // --- tsi_ssl_handshaker_factory constructors. ---
2202 
// Vtable assigned to impl->base.vtable of client handshaker factories by
// tsi_create_ssl_client_handshaker_factory_with_options. Its only initializer
// is the client factory's destroy function.
static tsi_ssl_handshaker_factory_vtable client_handshaker_factory_vtable = {
    tsi_ssl_client_handshaker_factory_destroy};
2205 
tsi_create_ssl_client_handshaker_factory(const tsi_ssl_pem_key_cert_pair * pem_key_cert_pair,const char * pem_root_certs,const char * cipher_suites,const char ** alpn_protocols,uint16_t num_alpn_protocols,tsi_ssl_client_handshaker_factory ** factory)2206 tsi_result tsi_create_ssl_client_handshaker_factory(
2207     const tsi_ssl_pem_key_cert_pair* pem_key_cert_pair,
2208     const char* pem_root_certs, const char* cipher_suites,
2209     const char** alpn_protocols, uint16_t num_alpn_protocols,
2210     tsi_ssl_client_handshaker_factory** factory) {
2211   tsi_ssl_client_handshaker_options options;
2212   options.pem_key_cert_pair = pem_key_cert_pair;
2213   options.pem_root_certs = pem_root_certs;
2214   options.cipher_suites = cipher_suites;
2215   options.alpn_protocols = alpn_protocols;
2216   options.num_alpn_protocols = num_alpn_protocols;
2217   return tsi_create_ssl_client_handshaker_factory_with_options(&options,
2218                                                                factory);
2219 }
2220 
tsi_create_ssl_client_handshaker_factory_with_options(const tsi_ssl_client_handshaker_options * options,tsi_ssl_client_handshaker_factory ** factory)2221 tsi_result tsi_create_ssl_client_handshaker_factory_with_options(
2222     const tsi_ssl_client_handshaker_options* options,
2223     tsi_ssl_client_handshaker_factory** factory) {
2224   SSL_CTX* ssl_context = nullptr;
2225   tsi_ssl_client_handshaker_factory* impl = nullptr;
2226   tsi_result result = TSI_OK;
2227 
2228   gpr_once_init(&g_init_openssl_once, init_openssl);
2229 
2230   if (factory == nullptr) return TSI_INVALID_ARGUMENT;
2231   *factory = nullptr;
2232   if (options->pem_root_certs == nullptr && options->root_store == nullptr &&
2233       !options->skip_server_certificate_verification) {
2234     return TSI_INVALID_ARGUMENT;
2235   }
2236 
2237 #if OPENSSL_VERSION_NUMBER >= 0x10100000
2238   ssl_context = SSL_CTX_new(TLS_method());
2239 #else
2240   ssl_context = SSL_CTX_new(TLSv1_2_method());
2241 #endif
2242 #if OPENSSL_VERSION_NUMBER >= 0x10101000 && !defined(LIBRESSL_VERSION_NUMBER)
2243   SSL_CTX_set_options(ssl_context, SSL_OP_NO_RENEGOTIATION);
2244 #endif
2245   if (ssl_context == nullptr) {
2246     grpc_core::LogSslErrorStack();
2247     LOG(ERROR) << "Could not create ssl context.";
2248     return TSI_INVALID_ARGUMENT;
2249   }
2250 
2251   result = tsi_set_min_and_max_tls_versions(
2252       ssl_context, options->min_tls_version, options->max_tls_version);
2253   if (result != TSI_OK) return result;
2254 
2255   impl = static_cast<tsi_ssl_client_handshaker_factory*>(
2256       gpr_zalloc(sizeof(*impl)));
2257   tsi_ssl_handshaker_factory_init(&impl->base);
2258   impl->base.vtable = &client_handshaker_factory_vtable;
2259   impl->ssl_context = ssl_context;
2260   if (options->session_cache != nullptr) {
2261     // Unref is called manually on factory destruction.
2262     impl->session_cache =
2263         reinterpret_cast<tsi::SslSessionLRUCache*>(options->session_cache)
2264             ->Ref();
2265     SSL_CTX_sess_set_new_cb(ssl_context,
2266                             server_handshaker_factory_new_session_callback);
2267     SSL_CTX_set_session_cache_mode(ssl_context, SSL_SESS_CACHE_CLIENT);
2268   }
2269 
2270 #if OPENSSL_VERSION_NUMBER >= 0x10101000 && !defined(LIBRESSL_VERSION_NUMBER)
2271   if (options->key_logger != nullptr) {
2272     impl->key_logger = options->key_logger->Ref();
2273     // SSL_CTX_set_keylog_callback is set here to register callback
2274     // when ssl/tls handshakes complete.
2275     SSL_CTX_set_keylog_callback(
2276         ssl_context,
2277         ssl_keylogging_callback<tsi_ssl_client_handshaker_factory>);
2278   }
2279 #endif
2280 
2281   if (options->session_cache != nullptr || options->key_logger != nullptr) {
2282     // Need to set factory at g_ssl_ctx_ex_factory_index
2283     SSL_CTX_set_ex_data(ssl_context, g_ssl_ctx_ex_factory_index, impl);
2284   }
2285 
2286   do {
2287     result = populate_ssl_context(ssl_context, options->pem_key_cert_pair,
2288                                   options->cipher_suites);
2289     if (result != TSI_OK) break;
2290 
2291 #if OPENSSL_VERSION_NUMBER >= 0x10100000
2292     // X509_STORE_up_ref is only available since OpenSSL 1.1.
2293     if (options->root_store != nullptr) {
2294       X509_STORE_up_ref(options->root_store->store);
2295       SSL_CTX_set_cert_store(ssl_context, options->root_store->store);
2296     }
2297 #endif
2298     if (OPENSSL_VERSION_NUMBER < 0x10100000 ||
2299         (options->root_store == nullptr &&
2300          options->pem_root_certs != nullptr)) {
2301       result = ssl_ctx_load_verification_certs(
2302           ssl_context, options->pem_root_certs, strlen(options->pem_root_certs),
2303           nullptr);
2304       X509_STORE* cert_store = SSL_CTX_get_cert_store(ssl_context);
2305 #if OPENSSL_VERSION_NUMBER >= 0x10100000
2306       X509_VERIFY_PARAM* param = X509_STORE_get0_param(cert_store);
2307 
2308 #else
2309       X509_VERIFY_PARAM* param = cert_store->param;
2310 #endif
2311 
2312       X509_VERIFY_PARAM_set_depth(param, kMaxChainLength);
2313       if (result != TSI_OK) {
2314         LOG(ERROR) << "Cannot load server root certificates.";
2315         break;
2316       }
2317     }
2318 
2319     if (options->num_alpn_protocols != 0) {
2320       result = build_alpn_protocol_name_list(
2321           options->alpn_protocols, options->num_alpn_protocols,
2322           &impl->alpn_protocol_list, &impl->alpn_protocol_list_length);
2323       if (result != TSI_OK) {
2324         LOG(ERROR) << "Building alpn list failed with error "
2325                    << tsi_result_to_string(result);
2326         break;
2327       }
2328 #if TSI_OPENSSL_ALPN_SUPPORT
2329       CHECK(impl->alpn_protocol_list_length < UINT_MAX);
2330       if (SSL_CTX_set_alpn_protos(
2331               ssl_context, impl->alpn_protocol_list,
2332               static_cast<unsigned int>(impl->alpn_protocol_list_length))) {
2333         LOG(ERROR) << "Could not set alpn protocol list to context.";
2334         result = TSI_INVALID_ARGUMENT;
2335         break;
2336       }
2337 #endif  // TSI_OPENSSL_ALPN_SUPPORT
2338       SSL_CTX_set_next_proto_select_cb(
2339           ssl_context, client_handshaker_factory_npn_callback, impl);
2340     }
2341   } while (false);
2342   if (result != TSI_OK) {
2343     tsi_ssl_handshaker_factory_unref(&impl->base);
2344     return result;
2345   }
2346   SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, nullptr);
2347   if (options->skip_server_certificate_verification) {
2348     SSL_CTX_set_cert_verify_callback(ssl_context, NullVerifyCallback, nullptr);
2349   } else {
2350     SSL_CTX_set_cert_verify_callback(ssl_context, CustomVerificationFunction,
2351                                      nullptr);
2352   }
2353 #if OPENSSL_VERSION_NUMBER >= 0x10100000 && !defined(LIBRESSL_VERSION_NUMBER)
2354   if (options->crl_provider != nullptr) {
2355     SSL_CTX_set_ex_data(impl->ssl_context, g_ssl_ctx_ex_crl_provider_index,
2356                         options->crl_provider.get());
2357   } else if (options->crl_directory != nullptr &&
2358              strcmp(options->crl_directory, "") != 0) {
2359     X509_STORE* cert_store = SSL_CTX_get_cert_store(ssl_context);
2360     X509_STORE_set_verify_cb(cert_store, verify_cb);
2361     if (!X509_STORE_load_locations(cert_store, nullptr,
2362                                    options->crl_directory)) {
2363       LOG(ERROR) << "Failed to load CRL File from directory.";
2364     } else {
2365       X509_VERIFY_PARAM* param = X509_STORE_get0_param(cert_store);
2366       X509_VERIFY_PARAM_set_flags(
2367           param, X509_V_FLAG_CRL_CHECK | X509_V_FLAG_CRL_CHECK_ALL);
2368     }
2369   }
2370 #endif
2371 
2372   *factory = impl;
2373   return TSI_OK;
2374 }
2375 
// Vtable assigned to impl->base.vtable of server handshaker factories by
// tsi_create_ssl_server_handshaker_factory_with_options. Its only initializer
// is the server factory's destroy function.
static tsi_ssl_handshaker_factory_vtable server_handshaker_factory_vtable = {
    tsi_ssl_server_handshaker_factory_destroy};
2378 
tsi_create_ssl_server_handshaker_factory(const tsi_ssl_pem_key_cert_pair * pem_key_cert_pairs,size_t num_key_cert_pairs,const char * pem_client_root_certs,int force_client_auth,const char * cipher_suites,const char ** alpn_protocols,uint16_t num_alpn_protocols,tsi_ssl_server_handshaker_factory ** factory)2379 tsi_result tsi_create_ssl_server_handshaker_factory(
2380     const tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs,
2381     size_t num_key_cert_pairs, const char* pem_client_root_certs,
2382     int force_client_auth, const char* cipher_suites,
2383     const char** alpn_protocols, uint16_t num_alpn_protocols,
2384     tsi_ssl_server_handshaker_factory** factory) {
2385   return tsi_create_ssl_server_handshaker_factory_ex(
2386       pem_key_cert_pairs, num_key_cert_pairs, pem_client_root_certs,
2387       force_client_auth ? TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY
2388                         : TSI_DONT_REQUEST_CLIENT_CERTIFICATE,
2389       cipher_suites, alpn_protocols, num_alpn_protocols, factory);
2390 }
2391 
tsi_create_ssl_server_handshaker_factory_ex(const tsi_ssl_pem_key_cert_pair * pem_key_cert_pairs,size_t num_key_cert_pairs,const char * pem_client_root_certs,tsi_client_certificate_request_type client_certificate_request,const char * cipher_suites,const char ** alpn_protocols,uint16_t num_alpn_protocols,tsi_ssl_server_handshaker_factory ** factory)2392 tsi_result tsi_create_ssl_server_handshaker_factory_ex(
2393     const tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs,
2394     size_t num_key_cert_pairs, const char* pem_client_root_certs,
2395     tsi_client_certificate_request_type client_certificate_request,
2396     const char* cipher_suites, const char** alpn_protocols,
2397     uint16_t num_alpn_protocols, tsi_ssl_server_handshaker_factory** factory) {
2398   tsi_ssl_server_handshaker_options options;
2399   options.pem_key_cert_pairs = pem_key_cert_pairs;
2400   options.num_key_cert_pairs = num_key_cert_pairs;
2401   options.pem_client_root_certs = pem_client_root_certs;
2402   options.client_certificate_request = client_certificate_request;
2403   options.cipher_suites = cipher_suites;
2404   options.alpn_protocols = alpn_protocols;
2405   options.num_alpn_protocols = num_alpn_protocols;
2406   return tsi_create_ssl_server_handshaker_factory_with_options(&options,
2407                                                                factory);
2408 }
2409 
tsi_create_ssl_server_handshaker_factory_with_options(const tsi_ssl_server_handshaker_options * options,tsi_ssl_server_handshaker_factory ** factory)2410 tsi_result tsi_create_ssl_server_handshaker_factory_with_options(
2411     const tsi_ssl_server_handshaker_options* options,
2412     tsi_ssl_server_handshaker_factory** factory) {
2413   tsi_ssl_server_handshaker_factory* impl = nullptr;
2414   tsi_result result = TSI_OK;
2415   size_t i = 0;
2416 
2417   gpr_once_init(&g_init_openssl_once, init_openssl);
2418 
2419   if (factory == nullptr) return TSI_INVALID_ARGUMENT;
2420   *factory = nullptr;
2421   if (options->num_key_cert_pairs == 0 ||
2422       options->pem_key_cert_pairs == nullptr) {
2423     return TSI_INVALID_ARGUMENT;
2424   }
2425 
2426   impl = static_cast<tsi_ssl_server_handshaker_factory*>(
2427       gpr_zalloc(sizeof(*impl)));
2428   tsi_ssl_handshaker_factory_init(&impl->base);
2429   impl->base.vtable = &server_handshaker_factory_vtable;
2430 
2431   impl->ssl_contexts = static_cast<SSL_CTX**>(
2432       gpr_zalloc(options->num_key_cert_pairs * sizeof(SSL_CTX*)));
2433   impl->ssl_context_x509_subject_names = static_cast<tsi_peer*>(
2434       gpr_zalloc(options->num_key_cert_pairs * sizeof(tsi_peer)));
2435   if (impl->ssl_contexts == nullptr ||
2436       impl->ssl_context_x509_subject_names == nullptr) {
2437     tsi_ssl_handshaker_factory_unref(&impl->base);
2438     return TSI_OUT_OF_RESOURCES;
2439   }
2440   impl->ssl_context_count = options->num_key_cert_pairs;
2441 
2442   if (options->num_alpn_protocols > 0) {
2443     result = build_alpn_protocol_name_list(
2444         options->alpn_protocols, options->num_alpn_protocols,
2445         &impl->alpn_protocol_list, &impl->alpn_protocol_list_length);
2446     if (result != TSI_OK) {
2447       tsi_ssl_handshaker_factory_unref(&impl->base);
2448       return result;
2449     }
2450   }
2451 
2452   if (options->key_logger != nullptr) {
2453     impl->key_logger = options->key_logger->Ref();
2454   }
2455 
2456   for (i = 0; i < options->num_key_cert_pairs; i++) {
2457     do {
2458 #if OPENSSL_VERSION_NUMBER >= 0x10100000
2459       impl->ssl_contexts[i] = SSL_CTX_new(TLS_method());
2460 #else
2461       impl->ssl_contexts[i] = SSL_CTX_new(TLSv1_2_method());
2462 #endif
2463 #if OPENSSL_VERSION_NUMBER >= 0x10101000 && !defined(LIBRESSL_VERSION_NUMBER)
2464       SSL_CTX_set_options(impl->ssl_contexts[i], SSL_OP_NO_RENEGOTIATION);
2465 #endif
2466       if (impl->ssl_contexts[i] == nullptr) {
2467         grpc_core::LogSslErrorStack();
2468         LOG(ERROR) << "Could not create ssl context.";
2469         result = TSI_OUT_OF_RESOURCES;
2470         break;
2471       }
2472 
2473       result = tsi_set_min_and_max_tls_versions(impl->ssl_contexts[i],
2474                                                 options->min_tls_version,
2475                                                 options->max_tls_version);
2476       if (result != TSI_OK) return result;
2477 
2478       result = populate_ssl_context(impl->ssl_contexts[i],
2479                                     &options->pem_key_cert_pairs[i],
2480                                     options->cipher_suites);
2481       if (result != TSI_OK) break;
2482 
2483       // TODO(elessar): Provide ability to disable session ticket keys.
2484 
2485       // Allow client cache sessions (it's needed for OpenSSL only).
2486       int set_sid_ctx_result = SSL_CTX_set_session_id_context(
2487           impl->ssl_contexts[i], kSslSessionIdContext,
2488           GPR_ARRAY_SIZE(kSslSessionIdContext));
2489       if (set_sid_ctx_result == 0) {
2490         LOG(ERROR) << "Failed to set session id context.";
2491         result = TSI_INTERNAL_ERROR;
2492         break;
2493       }
2494 
2495       if (options->session_ticket_key != nullptr) {
2496         if (SSL_CTX_set_tlsext_ticket_keys(
2497                 impl->ssl_contexts[i],
2498                 const_cast<char*>(options->session_ticket_key),
2499                 options->session_ticket_key_size) == 0) {
2500           LOG(ERROR) << "Invalid STEK size.";
2501           result = TSI_INVALID_ARGUMENT;
2502           break;
2503         }
2504       }
2505 
2506       if (options->pem_client_root_certs != nullptr) {
2507         STACK_OF(X509_NAME)* root_names = nullptr;
2508         result = ssl_ctx_load_verification_certs(
2509             impl->ssl_contexts[i], options->pem_client_root_certs,
2510             strlen(options->pem_client_root_certs),
2511             options->send_client_ca_list ? &root_names : nullptr);
2512         if (result != TSI_OK) {
2513           LOG(ERROR) << "Invalid verification certs.";
2514           break;
2515         }
2516         if (options->send_client_ca_list) {
2517           SSL_CTX_set_client_CA_list(impl->ssl_contexts[i], root_names);
2518         }
2519       }
2520       switch (options->client_certificate_request) {
2521         case TSI_DONT_REQUEST_CLIENT_CERTIFICATE:
2522           SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_NONE, nullptr);
2523           break;
2524         case TSI_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
2525           SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, nullptr);
2526           SSL_CTX_set_cert_verify_callback(impl->ssl_contexts[i],
2527                                            NullVerifyCallback, nullptr);
2528           break;
2529         case TSI_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY:
2530           SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, nullptr);
2531           SSL_CTX_set_cert_verify_callback(impl->ssl_contexts[i],
2532                                            CustomVerificationFunction, nullptr);
2533           break;
2534         case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
2535           SSL_CTX_set_verify(impl->ssl_contexts[i],
2536                              SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
2537                              nullptr);
2538           SSL_CTX_set_cert_verify_callback(impl->ssl_contexts[i],
2539                                            NullVerifyCallback, nullptr);
2540           break;
2541         case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY:
2542           SSL_CTX_set_verify(impl->ssl_contexts[i],
2543                              SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
2544                              nullptr);
2545           SSL_CTX_set_cert_verify_callback(impl->ssl_contexts[i],
2546                                            CustomVerificationFunction, nullptr);
2547           break;
2548       }
2549 
2550 #if OPENSSL_VERSION_NUMBER >= 0x10100000 && !defined(LIBRESSL_VERSION_NUMBER)
2551       if (options->crl_provider != nullptr) {
2552         SSL_CTX_set_ex_data(impl->ssl_contexts[i],
2553                             g_ssl_ctx_ex_crl_provider_index,
2554                             options->crl_provider.get());
2555       } else if (options->crl_directory != nullptr &&
2556                  strcmp(options->crl_directory, "") != 0) {
2557         X509_STORE* cert_store = SSL_CTX_get_cert_store(impl->ssl_contexts[i]);
2558         X509_STORE_set_verify_cb(cert_store, verify_cb);
2559         if (!X509_STORE_load_locations(cert_store, nullptr,
2560                                        options->crl_directory)) {
2561           LOG(ERROR) << "Failed to load CRL File from directory.";
2562         } else {
2563           X509_VERIFY_PARAM* param = X509_STORE_get0_param(cert_store);
2564           X509_VERIFY_PARAM_set_flags(
2565               param, X509_V_FLAG_CRL_CHECK | X509_V_FLAG_CRL_CHECK_ALL);
2566         }
2567       }
2568 #endif
2569 
2570       result = tsi_ssl_extract_x509_subject_names_from_pem_cert(
2571           options->pem_key_cert_pairs[i].cert_chain,
2572           &impl->ssl_context_x509_subject_names[i]);
2573       if (result != TSI_OK) break;
2574 
2575       SSL_CTX_set_tlsext_servername_callback(
2576           impl->ssl_contexts[i],
2577           ssl_server_handshaker_factory_servername_callback);
2578       SSL_CTX_set_tlsext_servername_arg(impl->ssl_contexts[i], impl);
2579 #if TSI_OPENSSL_ALPN_SUPPORT
2580       SSL_CTX_set_alpn_select_cb(impl->ssl_contexts[i],
2581                                  server_handshaker_factory_alpn_callback, impl);
2582 #endif  // TSI_OPENSSL_ALPN_SUPPORT
2583       SSL_CTX_set_next_protos_advertised_cb(
2584           impl->ssl_contexts[i],
2585           server_handshaker_factory_npn_advertised_callback, impl);
2586 
2587 #if OPENSSL_VERSION_NUMBER >= 0x10101000 && !defined(LIBRESSL_VERSION_NUMBER)
2588       // Register factory at index
2589       if (options->key_logger != nullptr) {
2590         // Need to set factory at g_ssl_ctx_ex_factory_index
2591         SSL_CTX_set_ex_data(impl->ssl_contexts[i], g_ssl_ctx_ex_factory_index,
2592                             impl);
2593         // SSL_CTX_set_keylog_callback is set here to register callback
2594         // when ssl/tls handshakes complete.
2595         SSL_CTX_set_keylog_callback(
2596             impl->ssl_contexts[i],
2597             ssl_keylogging_callback<tsi_ssl_server_handshaker_factory>);
2598       }
2599 #endif
2600     } while (false);
2601 
2602     if (result != TSI_OK) {
2603       tsi_ssl_handshaker_factory_unref(&impl->base);
2604       return result;
2605     }
2606   }
2607 
2608   *factory = impl;
2609   return TSI_OK;
2610 }
2611 
2612 // --- tsi_ssl utils. ---
2613 
tsi_ssl_peer_matches_name(const tsi_peer * peer,absl::string_view name)2614 int tsi_ssl_peer_matches_name(const tsi_peer* peer, absl::string_view name) {
2615   size_t i = 0;
2616   size_t san_count = 0;
2617   const tsi_peer_property* cn_property = nullptr;
2618   int like_ip = looks_like_ip_address(name);
2619 
2620   // Check the SAN first.
2621   for (i = 0; i < peer->property_count; i++) {
2622     const tsi_peer_property* property = &peer->properties[i];
2623     if (property->name == nullptr) continue;
2624     if (strcmp(property->name,
2625                TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY) == 0) {
2626       san_count++;
2627 
2628       absl::string_view entry(property->value.data, property->value.length);
2629       if (!like_ip && does_entry_match_name(entry, name)) {
2630         return 1;
2631       } else if (like_ip && name == entry) {
2632         // IP Addresses are exact matches only.
2633         return 1;
2634       }
2635     } else if (strcmp(property->name,
2636                       TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY) == 0) {
2637       cn_property = property;
2638     }
2639   }
2640 
2641   // If there's no SAN, try the CN, but only if its not like an IP Address
2642   if (san_count == 0 && cn_property != nullptr && !like_ip) {
2643     if (does_entry_match_name(absl::string_view(cn_property->value.data,
2644                                                 cn_property->value.length),
2645                               name)) {
2646       return 1;
2647     }
2648   }
2649 
2650   return 0;  // Not found.
2651 }
2652 
2653 // --- Testing support. ---
tsi_ssl_handshaker_factory_swap_vtable(tsi_ssl_handshaker_factory * factory,tsi_ssl_handshaker_factory_vtable * new_vtable)2654 const tsi_ssl_handshaker_factory_vtable* tsi_ssl_handshaker_factory_swap_vtable(
2655     tsi_ssl_handshaker_factory* factory,
2656     tsi_ssl_handshaker_factory_vtable* new_vtable) {
2657   CHECK_NE(factory, nullptr);
2658   CHECK_NE(factory->vtable, nullptr);
2659 
2660   const tsi_ssl_handshaker_factory_vtable* orig_vtable = factory->vtable;
2661   factory->vtable = new_vtable;
2662   return orig_vtable;
2663 }
2664