1 /*
2 *
3 * Copyright 2015 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19 #include <grpc/support/port_platform.h>
20
21 #include "src/core/tsi/grpc_shadow_boringssl.h"
22
23 #include "src/core/tsi/ssl_transport_security.h"
24
25 #include <limits.h>
26 #include <string.h>
27
28 /* TODO(jboeuf): refactor inet_ntop into a portability header. */
29 /* Note: for whomever reads this and tries to refactor this, this
30 can't be in grpc, it has to be in gpr. */
31 #ifdef GPR_WINDOWS
32 #include <ws2tcpip.h>
33 #else
34 #include <arpa/inet.h>
35 #include <sys/socket.h>
36 #endif
37
38 #include <grpc/support/alloc.h>
39 #include <grpc/support/log.h>
40 #include <grpc/support/string_util.h>
41 #include <grpc/support/sync.h>
42 #include <grpc/support/thd_id.h>
43
44 extern "C" {
45 #include <openssl/bio.h>
46 #include <openssl/crypto.h> /* For OPENSSL_free */
47 #include <openssl/err.h>
48 #include <openssl/ssl.h>
49 #include <openssl/x509.h>
50 #include <openssl/x509v3.h>
51 }
52
53 #include "src/core/lib/gpr/useful.h"
54 #include "src/core/tsi/ssl/session_cache/ssl_session_cache.h"
55 #include "src/core/tsi/ssl_types.h"
56 #include "src/core/tsi/transport_security.h"
57
58 /* --- Constants. ---*/
59
/* Bounds for the maximum protected frame size. The upper bound is 2^14 bytes,
   which matches the TLS maximum record plaintext size. Presumably a requested
   frame size is clamped into [LOWER_BOUND, UPPER_BOUND] where the frame
   protector is created (that code is outside this chunk). */
#define TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND 16384
#define TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND 1024
/* Initial size, in bytes, of the handshaker's outgoing byte buffer
   (presumably grown on demand; verify against the handshaker code). */
#define TSI_SSL_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE 1024

/* Putting a macro like this and littering the source file with #if is really
   bad practice.
   TODO(jboeuf): refactor all the #if / #endif in a separate module. */
/* ALPN support is enabled by default unless the build already defines
   TSI_OPENSSL_ALPN_SUPPORT. */
#ifndef TSI_OPENSSL_ALPN_SUPPORT
#define TSI_OPENSSL_ALPN_SUPPORT 1
#endif

/* TODO(jboeuf): I have not found a way to get this number dynamically from the
   SSL structure. This is what we would ultimately want though... */
/* Fixed estimate, in bytes, of the per-record protection overhead added by
   SSL (hedged: the exact overhead depends on the negotiated cipher). */
#define TSI_SSL_MAX_PROTECTION_OVERHEAD 100
74
75 /* --- Structure definitions. ---*/
76
/* A reusable set of trusted root certificates. The store is created and
   filled by tsi_ssl_root_certs_store_create() and freed by
   tsi_ssl_root_certs_store_destroy(). */
struct tsi_ssl_root_certs_store {
  X509_STORE* store; /* Owned; released with X509_STORE_free(). */
};
80
/* Common header shared by the client and server handshaker factories, which
   embed it as their first member (`base`). */
struct tsi_ssl_handshaker_factory {
  /* Factory-specific operations (presumably including destruction; the
     vtable type is declared elsewhere). */
  const tsi_ssl_handshaker_factory_vtable* vtable;
  /* Reference count for the factory (ref/unref logic is outside this chunk). */
  gpr_refcount refcount;
};
85
/* Factory for client-side SSL handshakers. */
struct tsi_ssl_client_handshaker_factory {
  tsi_ssl_handshaker_factory base; /* Must stay first (common header). */
  SSL_CTX* ssl_context;            /* Context shared by created handshakers. */
  /* ALPN protocol list in wire format, presumably produced by
     build_alpn_protocol_name_list(); verify at the factory constructor. */
  unsigned char* alpn_protocol_list;
  size_t alpn_protocol_list_length;
  /* Optional TLS session cache used for resumption (presumably null when no
     cache is configured — confirm against the factory constructor). */
  grpc_core::RefCountedPtr<tsi::SslSessionLRUCache> session_cache;
};
93
struct tsi_ssl_server_handshaker_factory {
  /* Several contexts to support SNI.
     The tsi_peer array contains the subject names of the server certificates
     associated with the contexts at the same index. */
  tsi_ssl_handshaker_factory base; /* Must stay first (common header). */
  SSL_CTX** ssl_contexts;          /* ssl_context_count entries. */
  /* Parallel to ssl_contexts: subject names of each context's certificate. */
  tsi_peer* ssl_context_x509_subject_names;
  size_t ssl_context_count;
  /* ALPN protocol list in wire format, presumably produced by
     build_alpn_protocol_name_list(); verify at the factory constructor. */
  unsigned char* alpn_protocol_list;
  size_t alpn_protocol_list_length;
};
105
/* State of one in-progress SSL handshake. `base` is first, presumably so the
   struct can be cast from a tsi_handshaker* (the casts are outside this
   chunk). */
typedef struct {
  tsi_handshaker base;
  SSL* ssl;
  BIO* network_io; /* BIO carrying the encrypted bytes exchanged with the peer
                      (hedged: inferred from use of network_io elsewhere). */
  tsi_result result; /* Presumably the handshake status so far. */
  unsigned char* outgoing_bytes_buffer;
  size_t outgoing_bytes_buffer_size;
  /* Presumably a counted reference to the creating factory; verify where the
     handshaker is created and destroyed. */
  tsi_ssl_handshaker_factory* factory_ref;
} tsi_ssl_handshaker;
115
/* Result of a completed SSL handshake. `base` is first, presumably so the
   struct can be cast from a tsi_handshaker_result* (casts are outside this
   chunk). */
typedef struct {
  tsi_handshaker_result base;
  SSL* ssl;
  BIO* network_io;
  /* Bytes received after the handshake finished (hedged: inferred from the
     name; confirm against the result-creation code). */
  unsigned char* unused_bytes;
  size_t unused_bytes_size;
} tsi_ssl_handshaker_result;
123
/* Frame protector backed by an established SSL connection. `base` must stay
   first: ssl_protector_protect() reinterpret_casts tsi_frame_protector* to
   this type. */
typedef struct {
  tsi_frame_protector base;
  SSL* ssl;
  BIO* network_io;      /* Encrypted bytes are read back from this BIO. */
  unsigned char* buffer; /* Plaintext staged until a full frame is ready. */
  size_t buffer_size;   /* Capacity of buffer; SSL_write is issued when full. */
  size_t buffer_offset; /* Number of plaintext bytes currently staged. */
} tsi_ssl_frame_protector;
132
133 /* --- Library Initialization. ---*/
134
/* Guards one-time library initialization (presumably passed to gpr_once_init
   together with init_openssl(); the call site is outside this chunk). */
static gpr_once g_init_openssl_once = GPR_ONCE_INIT;
/* SSL_CTX ex-data index allocated by init_openssl(); -1 until then. */
static int g_ssl_ctx_ex_factory_index = -1;
/* Session id context bytes "grpc" (4 bytes, not NUL-terminated). */
static const unsigned char kSslSessionIdContext[] = {'g', 'r', 'p', 'c'};
138
139 #if OPENSSL_VERSION_NUMBER < 0x10100000
140 static gpr_mu* g_openssl_mutexes = nullptr;
141 static void openssl_locking_cb(int mode, int type, const char* file,
142 int line) GRPC_UNUSED;
143 static unsigned long openssl_thread_id_cb(void) GRPC_UNUSED;
144
openssl_locking_cb(int mode,int type,const char * file,int line)145 static void openssl_locking_cb(int mode, int type, const char* file, int line) {
146 if (mode & CRYPTO_LOCK) {
147 gpr_mu_lock(&g_openssl_mutexes[type]);
148 } else {
149 gpr_mu_unlock(&g_openssl_mutexes[type]);
150 }
151 }
152
openssl_thread_id_cb(void)153 static unsigned long openssl_thread_id_cb(void) {
154 return static_cast<unsigned long>(gpr_thd_currentid());
155 }
156 #endif
157
init_openssl(void)158 static void init_openssl(void) {
159 SSL_library_init();
160 SSL_load_error_strings();
161 OpenSSL_add_all_algorithms();
162 #if OPENSSL_VERSION_NUMBER < 0x10100000
163 if (!CRYPTO_get_locking_callback()) {
164 int num_locks = CRYPTO_num_locks();
165 GPR_ASSERT(num_locks > 0);
166 g_openssl_mutexes = static_cast<gpr_mu*>(
167 gpr_malloc(static_cast<size_t>(num_locks) * sizeof(gpr_mu)));
168 for (int i = 0; i < num_locks; i++) {
169 gpr_mu_init(&g_openssl_mutexes[i]);
170 }
171 CRYPTO_set_locking_callback(openssl_locking_cb);
172 CRYPTO_set_id_callback(openssl_thread_id_cb);
173 } else {
174 gpr_log(GPR_INFO, "OpenSSL callback has already been set.");
175 }
176 #endif
177 g_ssl_ctx_ex_factory_index =
178 SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr);
179 GPR_ASSERT(g_ssl_ctx_ex_factory_index != -1);
180 }
181
182 /* --- Ssl utils. ---*/
183
ssl_error_string(int error)184 static const char* ssl_error_string(int error) {
185 switch (error) {
186 case SSL_ERROR_NONE:
187 return "SSL_ERROR_NONE";
188 case SSL_ERROR_ZERO_RETURN:
189 return "SSL_ERROR_ZERO_RETURN";
190 case SSL_ERROR_WANT_READ:
191 return "SSL_ERROR_WANT_READ";
192 case SSL_ERROR_WANT_WRITE:
193 return "SSL_ERROR_WANT_WRITE";
194 case SSL_ERROR_WANT_CONNECT:
195 return "SSL_ERROR_WANT_CONNECT";
196 case SSL_ERROR_WANT_ACCEPT:
197 return "SSL_ERROR_WANT_ACCEPT";
198 case SSL_ERROR_WANT_X509_LOOKUP:
199 return "SSL_ERROR_WANT_X509_LOOKUP";
200 case SSL_ERROR_SYSCALL:
201 return "SSL_ERROR_SYSCALL";
202 case SSL_ERROR_SSL:
203 return "SSL_ERROR_SSL";
204 default:
205 return "Unknown error";
206 }
207 }
208
209 /* TODO(jboeuf): Remove when we are past the debugging phase with this code. */
ssl_log_where_info(const SSL * ssl,int where,int flag,const char * msg)210 static void ssl_log_where_info(const SSL* ssl, int where, int flag,
211 const char* msg) {
212 if ((where & flag) && tsi_tracing_enabled.enabled()) {
213 gpr_log(GPR_INFO, "%20.20s - %30.30s - %5.10s", msg,
214 SSL_state_string_long(ssl), SSL_state_string(ssl));
215 }
216 }
217
218 /* Used for debugging. TODO(jboeuf): Remove when code is mature enough. */
ssl_info_callback(const SSL * ssl,int where,int ret)219 static void ssl_info_callback(const SSL* ssl, int where, int ret) {
220 if (ret == 0) {
221 gpr_log(GPR_ERROR, "ssl_info_callback: error occurred.\n");
222 return;
223 }
224
225 ssl_log_where_info(ssl, where, SSL_CB_LOOP, "LOOP");
226 ssl_log_where_info(ssl, where, SSL_CB_HANDSHAKE_START, "HANDSHAKE START");
227 ssl_log_where_info(ssl, where, SSL_CB_HANDSHAKE_DONE, "HANDSHAKE DONE");
228 }
229
/* Returns 1 if name looks like an IP address, 0 otherwise.
   This is a rough heuristic:
   - Any name containing ':' is treated as an IPv6 literal, since ':' is not
     allowed in DNS names. This covers forms such as "fe80::1" and
     "2001:db8::1".
   - Otherwise the name must be exactly four dot-separated groups of one to
     three decimal digits (IPv4). Group values are not range-checked (<= 255).
   A null name returns 0. */
static int looks_like_ip_address(const char* name) {
  if (name == nullptr) return 0;
  /* Check for ':' up front. Otherwise an IPv6 literal starting with a hex
     letter or a 4-digit group would be rejected by the IPv4 scan below. */
  if (strchr(name, ':') != nullptr) return 1;

  size_t dot_count = 0;
  size_t num_size = 0;
  for (const char* p = name; *p != '\0'; ++p) {
    const char c = *p;
    if (c >= '0' && c <= '9') {
      /* An IPv4 group has at most 3 decimal digits. */
      if (num_size >= 3) return 0;
      num_size++;
    } else if (c == '.') {
      /* At most 3 dots, and no empty group before a dot. */
      if (dot_count >= 3 || num_size == 0) return 0;
      dot_count++;
      num_size = 0;
    } else {
      return 0;
    }
  }
  /* Require exactly 3 dots and a non-empty final group. */
  if (dot_count < 3 || num_size == 0) return 0;
  return 1;
}
255
256 /* Gets the subject CN from an X509 cert. */
ssl_get_x509_common_name(X509 * cert,unsigned char ** utf8,size_t * utf8_size)257 static tsi_result ssl_get_x509_common_name(X509* cert, unsigned char** utf8,
258 size_t* utf8_size) {
259 int common_name_index = -1;
260 X509_NAME_ENTRY* common_name_entry = nullptr;
261 ASN1_STRING* common_name_asn1 = nullptr;
262 X509_NAME* subject_name = X509_get_subject_name(cert);
263 int utf8_returned_size = 0;
264 if (subject_name == nullptr) {
265 gpr_log(GPR_INFO, "Could not get subject name from certificate.");
266 return TSI_NOT_FOUND;
267 }
268 common_name_index =
269 X509_NAME_get_index_by_NID(subject_name, NID_commonName, -1);
270 if (common_name_index == -1) {
271 gpr_log(GPR_INFO, "Could not get common name of subject from certificate.");
272 return TSI_NOT_FOUND;
273 }
274 common_name_entry = X509_NAME_get_entry(subject_name, common_name_index);
275 if (common_name_entry == nullptr) {
276 gpr_log(GPR_ERROR, "Could not get common name entry from certificate.");
277 return TSI_INTERNAL_ERROR;
278 }
279 common_name_asn1 = X509_NAME_ENTRY_get_data(common_name_entry);
280 if (common_name_asn1 == nullptr) {
281 gpr_log(GPR_ERROR,
282 "Could not get common name entry asn1 from certificate.");
283 return TSI_INTERNAL_ERROR;
284 }
285 utf8_returned_size = ASN1_STRING_to_UTF8(utf8, common_name_asn1);
286 if (utf8_returned_size < 0) {
287 gpr_log(GPR_ERROR, "Could not extract utf8 from asn1 string.");
288 return TSI_OUT_OF_RESOURCES;
289 }
290 *utf8_size = static_cast<size_t>(utf8_returned_size);
291 return TSI_OK;
292 }
293
294 /* Gets the subject CN of an X509 cert as a tsi_peer_property. */
peer_property_from_x509_common_name(X509 * cert,tsi_peer_property * property)295 static tsi_result peer_property_from_x509_common_name(
296 X509* cert, tsi_peer_property* property) {
297 unsigned char* common_name;
298 size_t common_name_size;
299 tsi_result result =
300 ssl_get_x509_common_name(cert, &common_name, &common_name_size);
301 if (result != TSI_OK) {
302 if (result == TSI_NOT_FOUND) {
303 common_name = nullptr;
304 common_name_size = 0;
305 } else {
306 return result;
307 }
308 }
309 result = tsi_construct_string_peer_property(
310 TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY,
311 common_name == nullptr ? "" : reinterpret_cast<const char*>(common_name),
312 common_name_size, property);
313 OPENSSL_free(common_name);
314 return result;
315 }
316
317 /* Gets the X509 cert in PEM format as a tsi_peer_property. */
add_pem_certificate(X509 * cert,tsi_peer_property * property)318 static tsi_result add_pem_certificate(X509* cert, tsi_peer_property* property) {
319 BIO* bio = BIO_new(BIO_s_mem());
320 if (!PEM_write_bio_X509(bio, cert)) {
321 BIO_free(bio);
322 return TSI_INTERNAL_ERROR;
323 }
324 char* contents;
325 long len = BIO_get_mem_data(bio, &contents);
326 if (len <= 0) {
327 BIO_free(bio);
328 return TSI_INTERNAL_ERROR;
329 }
330 tsi_result result = tsi_construct_string_peer_property(
331 TSI_X509_PEM_CERT_PROPERTY, (const char*)contents,
332 static_cast<size_t>(len), property);
333 BIO_free(bio);
334 return result;
335 }
336
337 /* Gets the subject SANs from an X509 cert as a tsi_peer_property. */
add_subject_alt_names_properties_to_peer(tsi_peer * peer,GENERAL_NAMES * subject_alt_names,size_t subject_alt_name_count)338 static tsi_result add_subject_alt_names_properties_to_peer(
339 tsi_peer* peer, GENERAL_NAMES* subject_alt_names,
340 size_t subject_alt_name_count) {
341 size_t i;
342 tsi_result result = TSI_OK;
343
344 /* Reset for DNS entries filtering. */
345 peer->property_count -= subject_alt_name_count;
346
347 for (i = 0; i < subject_alt_name_count; i++) {
348 GENERAL_NAME* subject_alt_name =
349 sk_GENERAL_NAME_value(subject_alt_names, TSI_SIZE_AS_SIZE(i));
350 /* Filter out the non-dns entries names. */
351 if (subject_alt_name->type == GEN_DNS) {
352 unsigned char* name = nullptr;
353 int name_size;
354 name_size = ASN1_STRING_to_UTF8(&name, subject_alt_name->d.dNSName);
355 if (name_size < 0) {
356 gpr_log(GPR_ERROR, "Could not get utf8 from asn1 string.");
357 result = TSI_INTERNAL_ERROR;
358 break;
359 }
360 result = tsi_construct_string_peer_property(
361 TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY,
362 reinterpret_cast<const char*>(name), static_cast<size_t>(name_size),
363 &peer->properties[peer->property_count++]);
364 OPENSSL_free(name);
365 } else if (subject_alt_name->type == GEN_IPADD) {
366 char ntop_buf[INET6_ADDRSTRLEN];
367 int af;
368
369 if (subject_alt_name->d.iPAddress->length == 4) {
370 af = AF_INET;
371 } else if (subject_alt_name->d.iPAddress->length == 16) {
372 af = AF_INET6;
373 } else {
374 gpr_log(GPR_ERROR, "SAN IP Address contained invalid IP");
375 result = TSI_INTERNAL_ERROR;
376 break;
377 }
378 const char* name = inet_ntop(af, subject_alt_name->d.iPAddress->data,
379 ntop_buf, INET6_ADDRSTRLEN);
380 if (name == nullptr) {
381 gpr_log(GPR_ERROR, "Could not get IP string from asn1 octet.");
382 result = TSI_INTERNAL_ERROR;
383 break;
384 }
385
386 result = tsi_construct_string_peer_property_from_cstring(
387 TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY, name,
388 &peer->properties[peer->property_count++]);
389 }
390 if (result != TSI_OK) break;
391 }
392 return result;
393 }
394
395 /* Gets information about the peer's X509 cert as a tsi_peer object. */
peer_from_x509(X509 * cert,int include_certificate_type,tsi_peer * peer)396 static tsi_result peer_from_x509(X509* cert, int include_certificate_type,
397 tsi_peer* peer) {
398 /* TODO(jboeuf): Maybe add more properties. */
399 GENERAL_NAMES* subject_alt_names = static_cast<GENERAL_NAMES*>(
400 X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
401 int subject_alt_name_count =
402 (subject_alt_names != nullptr)
403 ? static_cast<int>(sk_GENERAL_NAME_num(subject_alt_names))
404 : 0;
405 size_t property_count;
406 tsi_result result;
407 GPR_ASSERT(subject_alt_name_count >= 0);
408 property_count = (include_certificate_type ? static_cast<size_t>(1) : 0) +
409 2 /* common name, certificate */ +
410 static_cast<size_t>(subject_alt_name_count);
411 result = tsi_construct_peer(property_count, peer);
412 if (result != TSI_OK) return result;
413 do {
414 if (include_certificate_type) {
415 result = tsi_construct_string_peer_property_from_cstring(
416 TSI_CERTIFICATE_TYPE_PEER_PROPERTY, TSI_X509_CERTIFICATE_TYPE,
417 &peer->properties[0]);
418 if (result != TSI_OK) break;
419 }
420 result = peer_property_from_x509_common_name(
421 cert, &peer->properties[include_certificate_type ? 1 : 0]);
422 if (result != TSI_OK) break;
423
424 result = add_pem_certificate(
425 cert, &peer->properties[include_certificate_type ? 2 : 1]);
426 if (result != TSI_OK) break;
427
428 if (subject_alt_name_count != 0) {
429 result = add_subject_alt_names_properties_to_peer(
430 peer, subject_alt_names, static_cast<size_t>(subject_alt_name_count));
431 if (result != TSI_OK) break;
432 }
433 } while (0);
434
435 if (subject_alt_names != nullptr) {
436 sk_GENERAL_NAME_pop_free(subject_alt_names, GENERAL_NAME_free);
437 }
438 if (result != TSI_OK) tsi_peer_destruct(peer);
439 return result;
440 }
441
442 /* Logs the SSL error stack. */
log_ssl_error_stack(void)443 static void log_ssl_error_stack(void) {
444 unsigned long err;
445 while ((err = ERR_get_error()) != 0) {
446 char details[256];
447 ERR_error_string_n(static_cast<uint32_t>(err), details, sizeof(details));
448 gpr_log(GPR_ERROR, "%s", details);
449 }
450 }
451
452 /* Performs an SSL_read and handle errors. */
do_ssl_read(SSL * ssl,unsigned char * unprotected_bytes,size_t * unprotected_bytes_size)453 static tsi_result do_ssl_read(SSL* ssl, unsigned char* unprotected_bytes,
454 size_t* unprotected_bytes_size) {
455 int read_from_ssl;
456 GPR_ASSERT(*unprotected_bytes_size <= INT_MAX);
457 read_from_ssl = SSL_read(ssl, unprotected_bytes,
458 static_cast<int>(*unprotected_bytes_size));
459 if (read_from_ssl <= 0) {
460 read_from_ssl = SSL_get_error(ssl, read_from_ssl);
461 switch (read_from_ssl) {
462 case SSL_ERROR_ZERO_RETURN: /* Received a close_notify alert. */
463 case SSL_ERROR_WANT_READ: /* We need more data to finish the frame. */
464 *unprotected_bytes_size = 0;
465 return TSI_OK;
466 case SSL_ERROR_WANT_WRITE:
467 gpr_log(
468 GPR_ERROR,
469 "Peer tried to renegotiate SSL connection. This is unsupported.");
470 return TSI_UNIMPLEMENTED;
471 case SSL_ERROR_SSL:
472 gpr_log(GPR_ERROR, "Corruption detected.");
473 log_ssl_error_stack();
474 return TSI_DATA_CORRUPTED;
475 default:
476 gpr_log(GPR_ERROR, "SSL_read failed with error %s.",
477 ssl_error_string(read_from_ssl));
478 return TSI_PROTOCOL_FAILURE;
479 }
480 }
481 *unprotected_bytes_size = static_cast<size_t>(read_from_ssl);
482 return TSI_OK;
483 }
484
485 /* Performs an SSL_write and handle errors. */
do_ssl_write(SSL * ssl,unsigned char * unprotected_bytes,size_t unprotected_bytes_size)486 static tsi_result do_ssl_write(SSL* ssl, unsigned char* unprotected_bytes,
487 size_t unprotected_bytes_size) {
488 int ssl_write_result;
489 GPR_ASSERT(unprotected_bytes_size <= INT_MAX);
490 ssl_write_result = SSL_write(ssl, unprotected_bytes,
491 static_cast<int>(unprotected_bytes_size));
492 if (ssl_write_result < 0) {
493 ssl_write_result = SSL_get_error(ssl, ssl_write_result);
494 if (ssl_write_result == SSL_ERROR_WANT_READ) {
495 gpr_log(GPR_ERROR,
496 "Peer tried to renegotiate SSL connection. This is unsupported.");
497 return TSI_UNIMPLEMENTED;
498 } else {
499 gpr_log(GPR_ERROR, "SSL_write failed with error %s.",
500 ssl_error_string(ssl_write_result));
501 return TSI_INTERNAL_ERROR;
502 }
503 }
504 return TSI_OK;
505 }
506
507 /* Loads an in-memory PEM certificate chain into the SSL context. */
ssl_ctx_use_certificate_chain(SSL_CTX * context,const char * pem_cert_chain,size_t pem_cert_chain_size)508 static tsi_result ssl_ctx_use_certificate_chain(SSL_CTX* context,
509 const char* pem_cert_chain,
510 size_t pem_cert_chain_size) {
511 tsi_result result = TSI_OK;
512 X509* certificate = nullptr;
513 BIO* pem;
514 GPR_ASSERT(pem_cert_chain_size <= INT_MAX);
515 pem = BIO_new_mem_buf((void*)pem_cert_chain,
516 static_cast<int>(pem_cert_chain_size));
517 if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
518
519 do {
520 certificate = PEM_read_bio_X509_AUX(pem, nullptr, nullptr, (void*)"");
521 if (certificate == nullptr) {
522 result = TSI_INVALID_ARGUMENT;
523 break;
524 }
525 if (!SSL_CTX_use_certificate(context, certificate)) {
526 result = TSI_INVALID_ARGUMENT;
527 break;
528 }
529 while (1) {
530 X509* certificate_authority =
531 PEM_read_bio_X509(pem, nullptr, nullptr, (void*)"");
532 if (certificate_authority == nullptr) {
533 ERR_clear_error();
534 break; /* Done reading. */
535 }
536 if (!SSL_CTX_add_extra_chain_cert(context, certificate_authority)) {
537 X509_free(certificate_authority);
538 result = TSI_INVALID_ARGUMENT;
539 break;
540 }
541 /* We don't need to free certificate_authority as its ownership has been
542 transfered to the context. That is not the case for certificate though.
543 */
544 }
545 } while (0);
546
547 if (certificate != nullptr) X509_free(certificate);
548 BIO_free(pem);
549 return result;
550 }
551
552 /* Loads an in-memory PEM private key into the SSL context. */
ssl_ctx_use_private_key(SSL_CTX * context,const char * pem_key,size_t pem_key_size)553 static tsi_result ssl_ctx_use_private_key(SSL_CTX* context, const char* pem_key,
554 size_t pem_key_size) {
555 tsi_result result = TSI_OK;
556 EVP_PKEY* private_key = nullptr;
557 BIO* pem;
558 GPR_ASSERT(pem_key_size <= INT_MAX);
559 pem = BIO_new_mem_buf((void*)pem_key, static_cast<int>(pem_key_size));
560 if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
561 do {
562 private_key = PEM_read_bio_PrivateKey(pem, nullptr, nullptr, (void*)"");
563 if (private_key == nullptr) {
564 result = TSI_INVALID_ARGUMENT;
565 break;
566 }
567 if (!SSL_CTX_use_PrivateKey(context, private_key)) {
568 result = TSI_INVALID_ARGUMENT;
569 break;
570 }
571 } while (0);
572 if (private_key != nullptr) EVP_PKEY_free(private_key);
573 BIO_free(pem);
574 return result;
575 }
576
577 /* Loads in-memory PEM verification certs into the SSL context and optionally
578 returns the verification cert names (root_names can be NULL). */
x509_store_load_certs(X509_STORE * cert_store,const char * pem_roots,size_t pem_roots_size,STACK_OF (X509_NAME)** root_names)579 static tsi_result x509_store_load_certs(X509_STORE* cert_store,
580 const char* pem_roots,
581 size_t pem_roots_size,
582 STACK_OF(X509_NAME) * *root_names) {
583 tsi_result result = TSI_OK;
584 size_t num_roots = 0;
585 X509* root = nullptr;
586 X509_NAME* root_name = nullptr;
587 BIO* pem;
588 GPR_ASSERT(pem_roots_size <= INT_MAX);
589 pem = BIO_new_mem_buf((void*)pem_roots, static_cast<int>(pem_roots_size));
590 if (cert_store == nullptr) return TSI_INVALID_ARGUMENT;
591 if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
592 if (root_names != nullptr) {
593 *root_names = sk_X509_NAME_new_null();
594 if (*root_names == nullptr) return TSI_OUT_OF_RESOURCES;
595 }
596
597 while (1) {
598 root = PEM_read_bio_X509_AUX(pem, nullptr, nullptr, (void*)"");
599 if (root == nullptr) {
600 ERR_clear_error();
601 break; /* We're at the end of stream. */
602 }
603 if (root_names != nullptr) {
604 root_name = X509_get_subject_name(root);
605 if (root_name == nullptr) {
606 gpr_log(GPR_ERROR, "Could not get name from root certificate.");
607 result = TSI_INVALID_ARGUMENT;
608 break;
609 }
610 root_name = X509_NAME_dup(root_name);
611 if (root_name == nullptr) {
612 result = TSI_OUT_OF_RESOURCES;
613 break;
614 }
615 sk_X509_NAME_push(*root_names, root_name);
616 root_name = nullptr;
617 }
618 if (!X509_STORE_add_cert(cert_store, root)) {
619 gpr_log(GPR_ERROR, "Could not add root certificate to ssl context.");
620 result = TSI_INTERNAL_ERROR;
621 break;
622 }
623 X509_free(root);
624 num_roots++;
625 }
626
627 if (num_roots == 0) {
628 gpr_log(GPR_ERROR, "Could not load any root certificate.");
629 result = TSI_INVALID_ARGUMENT;
630 }
631
632 if (result != TSI_OK) {
633 if (root != nullptr) X509_free(root);
634 if (root_names != nullptr) {
635 sk_X509_NAME_pop_free(*root_names, X509_NAME_free);
636 *root_names = nullptr;
637 if (root_name != nullptr) X509_NAME_free(root_name);
638 }
639 }
640 BIO_free(pem);
641 return result;
642 }
643
ssl_ctx_load_verification_certs(SSL_CTX * context,const char * pem_roots,size_t pem_roots_size,STACK_OF (X509_NAME)** root_name)644 static tsi_result ssl_ctx_load_verification_certs(SSL_CTX* context,
645 const char* pem_roots,
646 size_t pem_roots_size,
647 STACK_OF(X509_NAME) *
648 *root_name) {
649 X509_STORE* cert_store = SSL_CTX_get_cert_store(context);
650 return x509_store_load_certs(cert_store, pem_roots, pem_roots_size,
651 root_name);
652 }
653
654 /* Populates the SSL context with a private key and a cert chain, and sets the
655 cipher list and the ephemeral ECDH key. */
populate_ssl_context(SSL_CTX * context,const tsi_ssl_pem_key_cert_pair * key_cert_pair,const char * cipher_list)656 static tsi_result populate_ssl_context(
657 SSL_CTX* context, const tsi_ssl_pem_key_cert_pair* key_cert_pair,
658 const char* cipher_list) {
659 tsi_result result = TSI_OK;
660 if (key_cert_pair != nullptr) {
661 if (key_cert_pair->cert_chain != nullptr) {
662 result = ssl_ctx_use_certificate_chain(context, key_cert_pair->cert_chain,
663 strlen(key_cert_pair->cert_chain));
664 if (result != TSI_OK) {
665 gpr_log(GPR_ERROR, "Invalid cert chain file.");
666 return result;
667 }
668 }
669 if (key_cert_pair->private_key != nullptr) {
670 result = ssl_ctx_use_private_key(context, key_cert_pair->private_key,
671 strlen(key_cert_pair->private_key));
672 if (result != TSI_OK || !SSL_CTX_check_private_key(context)) {
673 gpr_log(GPR_ERROR, "Invalid private key.");
674 return result != TSI_OK ? result : TSI_INVALID_ARGUMENT;
675 }
676 }
677 }
678 if ((cipher_list != nullptr) &&
679 !SSL_CTX_set_cipher_list(context, cipher_list)) {
680 gpr_log(GPR_ERROR, "Invalid cipher list: %s.", cipher_list);
681 return TSI_INVALID_ARGUMENT;
682 }
683 {
684 EC_KEY* ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
685 if (!SSL_CTX_set_tmp_ecdh(context, ecdh)) {
686 gpr_log(GPR_ERROR, "Could not set ephemeral ECDH key.");
687 EC_KEY_free(ecdh);
688 return TSI_INTERNAL_ERROR;
689 }
690 SSL_CTX_set_options(context, SSL_OP_SINGLE_ECDH_USE);
691 EC_KEY_free(ecdh);
692 }
693 return TSI_OK;
694 }
695
696 /* Extracts the CN and the SANs from an X509 cert as a peer object. */
extract_x509_subject_names_from_pem_cert(const char * pem_cert,tsi_peer * peer)697 static tsi_result extract_x509_subject_names_from_pem_cert(const char* pem_cert,
698 tsi_peer* peer) {
699 tsi_result result = TSI_OK;
700 X509* cert = nullptr;
701 BIO* pem;
702 pem = BIO_new_mem_buf((void*)pem_cert, static_cast<int>(strlen(pem_cert)));
703 if (pem == nullptr) return TSI_OUT_OF_RESOURCES;
704
705 cert = PEM_read_bio_X509(pem, nullptr, nullptr, (void*)"");
706 if (cert == nullptr) {
707 gpr_log(GPR_ERROR, "Invalid certificate");
708 result = TSI_INVALID_ARGUMENT;
709 } else {
710 result = peer_from_x509(cert, 0, peer);
711 }
712 if (cert != nullptr) X509_free(cert);
713 BIO_free(pem);
714 return result;
715 }
716
717 /* Builds the alpn protocol name list according to rfc 7301. */
build_alpn_protocol_name_list(const char ** alpn_protocols,uint16_t num_alpn_protocols,unsigned char ** protocol_name_list,size_t * protocol_name_list_length)718 static tsi_result build_alpn_protocol_name_list(
719 const char** alpn_protocols, uint16_t num_alpn_protocols,
720 unsigned char** protocol_name_list, size_t* protocol_name_list_length) {
721 uint16_t i;
722 unsigned char* current;
723 *protocol_name_list = nullptr;
724 *protocol_name_list_length = 0;
725 if (num_alpn_protocols == 0) return TSI_INVALID_ARGUMENT;
726 for (i = 0; i < num_alpn_protocols; i++) {
727 size_t length =
728 alpn_protocols[i] == nullptr ? 0 : strlen(alpn_protocols[i]);
729 if (length == 0 || length > 255) {
730 gpr_log(GPR_ERROR, "Invalid protocol name length: %d.",
731 static_cast<int>(length));
732 return TSI_INVALID_ARGUMENT;
733 }
734 *protocol_name_list_length += length + 1;
735 }
736 *protocol_name_list =
737 static_cast<unsigned char*>(gpr_malloc(*protocol_name_list_length));
738 if (*protocol_name_list == nullptr) return TSI_OUT_OF_RESOURCES;
739 current = *protocol_name_list;
740 for (i = 0; i < num_alpn_protocols; i++) {
741 size_t length = strlen(alpn_protocols[i]);
742 *(current++) = static_cast<uint8_t>(length); /* max checked above. */
743 memcpy(current, alpn_protocols[i], length);
744 current += length;
745 }
746 /* Safety check. */
747 if ((current < *protocol_name_list) ||
748 (static_cast<uintptr_t>(current - *protocol_name_list) !=
749 *protocol_name_list_length)) {
750 return TSI_INTERNAL_ERROR;
751 }
752 return TSI_OK;
753 }
754
755 // The verification callback is used for clients that don't really care about
756 // the server's certificate, but we need to pull it anyway, in case a higher
757 // layer wants to look at it. In this case the verification may fail, but
758 // we don't really care.
NullVerifyCallback(int preverify_ok,X509_STORE_CTX * ctx)759 static int NullVerifyCallback(int preverify_ok, X509_STORE_CTX* ctx) {
760 return 1;
761 }
762
763 /* --- tsi_ssl_root_certs_store methods implementation. ---*/
764
tsi_ssl_root_certs_store_create(const char * pem_roots)765 tsi_ssl_root_certs_store* tsi_ssl_root_certs_store_create(
766 const char* pem_roots) {
767 if (pem_roots == nullptr) {
768 gpr_log(GPR_ERROR, "The root certificates are empty.");
769 return nullptr;
770 }
771 tsi_ssl_root_certs_store* root_store = static_cast<tsi_ssl_root_certs_store*>(
772 gpr_zalloc(sizeof(tsi_ssl_root_certs_store)));
773 if (root_store == nullptr) {
774 gpr_log(GPR_ERROR, "Could not allocate buffer for ssl_root_certs_store.");
775 return nullptr;
776 }
777 root_store->store = X509_STORE_new();
778 if (root_store->store == nullptr) {
779 gpr_log(GPR_ERROR, "Could not allocate buffer for X509_STORE.");
780 gpr_free(root_store);
781 return nullptr;
782 }
783 tsi_result result = x509_store_load_certs(root_store->store, pem_roots,
784 strlen(pem_roots), nullptr);
785 if (result != TSI_OK) {
786 gpr_log(GPR_ERROR, "Could not load root certificates.");
787 X509_STORE_free(root_store->store);
788 gpr_free(root_store);
789 return nullptr;
790 }
791 return root_store;
792 }
793
tsi_ssl_root_certs_store_destroy(tsi_ssl_root_certs_store * self)794 void tsi_ssl_root_certs_store_destroy(tsi_ssl_root_certs_store* self) {
795 if (self == nullptr) return;
796 X509_STORE_free(self->store);
797 gpr_free(self);
798 }
799
800 /* --- tsi_ssl_session_cache methods implementation. ---*/
801
tsi_ssl_session_cache_create_lru(size_t capacity)802 tsi_ssl_session_cache* tsi_ssl_session_cache_create_lru(size_t capacity) {
803 /* Pointer will be dereferenced by unref call. */
804 return reinterpret_cast<tsi_ssl_session_cache*>(
805 tsi::SslSessionLRUCache::Create(capacity).release());
806 }
807
tsi_ssl_session_cache_ref(tsi_ssl_session_cache * cache)808 void tsi_ssl_session_cache_ref(tsi_ssl_session_cache* cache) {
809 /* Pointer will be dereferenced by unref call. */
810 reinterpret_cast<tsi::SslSessionLRUCache*>(cache)->Ref().release();
811 }
812
tsi_ssl_session_cache_unref(tsi_ssl_session_cache * cache)813 void tsi_ssl_session_cache_unref(tsi_ssl_session_cache* cache) {
814 reinterpret_cast<tsi::SslSessionLRUCache*>(cache)->Unref();
815 }
816
817 /* --- tsi_frame_protector methods implementation. ---*/
818
ssl_protector_protect(tsi_frame_protector * self,const unsigned char * unprotected_bytes,size_t * unprotected_bytes_size,unsigned char * protected_output_frames,size_t * protected_output_frames_size)819 static tsi_result ssl_protector_protect(tsi_frame_protector* self,
820 const unsigned char* unprotected_bytes,
821 size_t* unprotected_bytes_size,
822 unsigned char* protected_output_frames,
823 size_t* protected_output_frames_size) {
824 tsi_ssl_frame_protector* impl =
825 reinterpret_cast<tsi_ssl_frame_protector*>(self);
826 int read_from_ssl;
827 size_t available;
828 tsi_result result = TSI_OK;
829
830 /* First see if we have some pending data in the SSL BIO. */
831 int pending_in_ssl = static_cast<int>(BIO_pending(impl->network_io));
832 if (pending_in_ssl > 0) {
833 *unprotected_bytes_size = 0;
834 GPR_ASSERT(*protected_output_frames_size <= INT_MAX);
835 read_from_ssl = BIO_read(impl->network_io, protected_output_frames,
836 static_cast<int>(*protected_output_frames_size));
837 if (read_from_ssl < 0) {
838 gpr_log(GPR_ERROR,
839 "Could not read from BIO even though some data is pending");
840 return TSI_INTERNAL_ERROR;
841 }
842 *protected_output_frames_size = static_cast<size_t>(read_from_ssl);
843 return TSI_OK;
844 }
845
846 /* Now see if we can send a complete frame. */
847 available = impl->buffer_size - impl->buffer_offset;
848 if (available > *unprotected_bytes_size) {
849 /* If we cannot, just copy the data in our internal buffer. */
850 memcpy(impl->buffer + impl->buffer_offset, unprotected_bytes,
851 *unprotected_bytes_size);
852 impl->buffer_offset += *unprotected_bytes_size;
853 *protected_output_frames_size = 0;
854 return TSI_OK;
855 }
856
857 /* If we can, prepare the buffer, send it to SSL_write and read. */
858 memcpy(impl->buffer + impl->buffer_offset, unprotected_bytes, available);
859 result = do_ssl_write(impl->ssl, impl->buffer, impl->buffer_size);
860 if (result != TSI_OK) return result;
861
862 GPR_ASSERT(*protected_output_frames_size <= INT_MAX);
863 read_from_ssl = BIO_read(impl->network_io, protected_output_frames,
864 static_cast<int>(*protected_output_frames_size));
865 if (read_from_ssl < 0) {
866 gpr_log(GPR_ERROR, "Could not read from BIO after SSL_write.");
867 return TSI_INTERNAL_ERROR;
868 }
869 *protected_output_frames_size = static_cast<size_t>(read_from_ssl);
870 *unprotected_bytes_size = available;
871 impl->buffer_offset = 0;
872 return TSI_OK;
873 }
874
ssl_protector_protect_flush(tsi_frame_protector * self,unsigned char * protected_output_frames,size_t * protected_output_frames_size,size_t * still_pending_size)875 static tsi_result ssl_protector_protect_flush(
876 tsi_frame_protector* self, unsigned char* protected_output_frames,
877 size_t* protected_output_frames_size, size_t* still_pending_size) {
878 tsi_result result = TSI_OK;
879 tsi_ssl_frame_protector* impl =
880 reinterpret_cast<tsi_ssl_frame_protector*>(self);
881 int read_from_ssl = 0;
882 int pending;
883
884 if (impl->buffer_offset != 0) {
885 result = do_ssl_write(impl->ssl, impl->buffer, impl->buffer_offset);
886 if (result != TSI_OK) return result;
887 impl->buffer_offset = 0;
888 }
889
890 pending = static_cast<int>(BIO_pending(impl->network_io));
891 GPR_ASSERT(pending >= 0);
892 *still_pending_size = static_cast<size_t>(pending);
893 if (*still_pending_size == 0) return TSI_OK;
894
895 GPR_ASSERT(*protected_output_frames_size <= INT_MAX);
896 read_from_ssl = BIO_read(impl->network_io, protected_output_frames,
897 static_cast<int>(*protected_output_frames_size));
898 if (read_from_ssl <= 0) {
899 gpr_log(GPR_ERROR, "Could not read from BIO after SSL_write.");
900 return TSI_INTERNAL_ERROR;
901 }
902 *protected_output_frames_size = static_cast<size_t>(read_from_ssl);
903 pending = static_cast<int>(BIO_pending(impl->network_io));
904 GPR_ASSERT(pending >= 0);
905 *still_pending_size = static_cast<size_t>(pending);
906 return TSI_OK;
907 }
908
ssl_protector_unprotect(tsi_frame_protector * self,const unsigned char * protected_frames_bytes,size_t * protected_frames_bytes_size,unsigned char * unprotected_bytes,size_t * unprotected_bytes_size)909 static tsi_result ssl_protector_unprotect(
910 tsi_frame_protector* self, const unsigned char* protected_frames_bytes,
911 size_t* protected_frames_bytes_size, unsigned char* unprotected_bytes,
912 size_t* unprotected_bytes_size) {
913 tsi_result result = TSI_OK;
914 int written_into_ssl = 0;
915 size_t output_bytes_size = *unprotected_bytes_size;
916 size_t output_bytes_offset = 0;
917 tsi_ssl_frame_protector* impl =
918 reinterpret_cast<tsi_ssl_frame_protector*>(self);
919
920 /* First, try to read remaining data from ssl. */
921 result = do_ssl_read(impl->ssl, unprotected_bytes, unprotected_bytes_size);
922 if (result != TSI_OK) return result;
923 if (*unprotected_bytes_size == output_bytes_size) {
924 /* We have read everything we could and cannot process any more input. */
925 *protected_frames_bytes_size = 0;
926 return TSI_OK;
927 }
928 output_bytes_offset = *unprotected_bytes_size;
929 unprotected_bytes += output_bytes_offset;
930 *unprotected_bytes_size = output_bytes_size - output_bytes_offset;
931
932 /* Then, try to write some data to ssl. */
933 GPR_ASSERT(*protected_frames_bytes_size <= INT_MAX);
934 written_into_ssl = BIO_write(impl->network_io, protected_frames_bytes,
935 static_cast<int>(*protected_frames_bytes_size));
936 if (written_into_ssl < 0) {
937 gpr_log(GPR_ERROR, "Sending protected frame to ssl failed with %d",
938 written_into_ssl);
939 return TSI_INTERNAL_ERROR;
940 }
941 *protected_frames_bytes_size = static_cast<size_t>(written_into_ssl);
942
943 /* Now try to read some data again. */
944 result = do_ssl_read(impl->ssl, unprotected_bytes, unprotected_bytes_size);
945 if (result == TSI_OK) {
946 /* Don't forget to output the total number of bytes read. */
947 *unprotected_bytes_size += output_bytes_offset;
948 }
949 return result;
950 }
951
ssl_protector_destroy(tsi_frame_protector * self)952 static void ssl_protector_destroy(tsi_frame_protector* self) {
953 tsi_ssl_frame_protector* impl =
954 reinterpret_cast<tsi_ssl_frame_protector*>(self);
955 if (impl->buffer != nullptr) gpr_free(impl->buffer);
956 if (impl->ssl != nullptr) SSL_free(impl->ssl);
957 if (impl->network_io != nullptr) BIO_free(impl->network_io);
958 gpr_free(self);
959 }
960
961 static const tsi_frame_protector_vtable frame_protector_vtable = {
962 ssl_protector_protect,
963 ssl_protector_protect_flush,
964 ssl_protector_unprotect,
965 ssl_protector_destroy,
966 };
967
/* --- tsi_ssl_handshaker_factory common methods implementation. --- */
969
tsi_ssl_handshaker_factory_destroy(tsi_ssl_handshaker_factory * self)970 static void tsi_ssl_handshaker_factory_destroy(
971 tsi_ssl_handshaker_factory* self) {
972 if (self == nullptr) return;
973
974 if (self->vtable != nullptr && self->vtable->destroy != nullptr) {
975 self->vtable->destroy(self);
976 }
977 /* Note, we don't free(self) here because this object is always directly
978 * embedded in another object. If tsi_ssl_handshaker_factory_init allocates
979 * any memory, it should be free'd here. */
980 }
981
tsi_ssl_handshaker_factory_ref(tsi_ssl_handshaker_factory * self)982 static tsi_ssl_handshaker_factory* tsi_ssl_handshaker_factory_ref(
983 tsi_ssl_handshaker_factory* self) {
984 if (self == nullptr) return nullptr;
985 gpr_refn(&self->refcount, 1);
986 return self;
987 }
988
tsi_ssl_handshaker_factory_unref(tsi_ssl_handshaker_factory * self)989 static void tsi_ssl_handshaker_factory_unref(tsi_ssl_handshaker_factory* self) {
990 if (self == nullptr) return;
991
992 if (gpr_unref(&self->refcount)) {
993 tsi_ssl_handshaker_factory_destroy(self);
994 }
995 }
996
997 static tsi_ssl_handshaker_factory_vtable handshaker_factory_vtable = {nullptr};
998
999 /* Initializes a tsi_ssl_handshaker_factory object. Caller is responsible for
1000 * allocating memory for the factory. */
tsi_ssl_handshaker_factory_init(tsi_ssl_handshaker_factory * factory)1001 static void tsi_ssl_handshaker_factory_init(
1002 tsi_ssl_handshaker_factory* factory) {
1003 GPR_ASSERT(factory != nullptr);
1004
1005 factory->vtable = &handshaker_factory_vtable;
1006 gpr_ref_init(&factory->refcount, 1);
1007 }
1008
1009 /* --- tsi_handshaker_result methods implementation. ---*/
1010
ssl_handshaker_result_extract_peer(const tsi_handshaker_result * self,tsi_peer * peer)1011 static tsi_result ssl_handshaker_result_extract_peer(
1012 const tsi_handshaker_result* self, tsi_peer* peer) {
1013 tsi_result result = TSI_OK;
1014 const unsigned char* alpn_selected = nullptr;
1015 unsigned int alpn_selected_len;
1016 const tsi_ssl_handshaker_result* impl =
1017 reinterpret_cast<const tsi_ssl_handshaker_result*>(self);
1018 X509* peer_cert = SSL_get_peer_certificate(impl->ssl);
1019 if (peer_cert != nullptr) {
1020 result = peer_from_x509(peer_cert, 1, peer);
1021 X509_free(peer_cert);
1022 if (result != TSI_OK) return result;
1023 }
1024 #if TSI_OPENSSL_ALPN_SUPPORT
1025 SSL_get0_alpn_selected(impl->ssl, &alpn_selected, &alpn_selected_len);
1026 #endif /* TSI_OPENSSL_ALPN_SUPPORT */
1027 if (alpn_selected == nullptr) {
1028 /* Try npn. */
1029 SSL_get0_next_proto_negotiated(impl->ssl, &alpn_selected,
1030 &alpn_selected_len);
1031 }
1032
1033 // 1 is for session reused property.
1034 size_t new_property_count = peer->property_count + 1;
1035 if (alpn_selected != nullptr) new_property_count++;
1036 tsi_peer_property* new_properties = static_cast<tsi_peer_property*>(
1037 gpr_zalloc(sizeof(*new_properties) * new_property_count));
1038 for (size_t i = 0; i < peer->property_count; i++) {
1039 new_properties[i] = peer->properties[i];
1040 }
1041 if (peer->properties != nullptr) gpr_free(peer->properties);
1042 peer->properties = new_properties;
1043
1044 if (alpn_selected != nullptr) {
1045 result = tsi_construct_string_peer_property(
1046 TSI_SSL_ALPN_SELECTED_PROTOCOL,
1047 reinterpret_cast<const char*>(alpn_selected), alpn_selected_len,
1048 &peer->properties[peer->property_count]);
1049 if (result != TSI_OK) return result;
1050 peer->property_count++;
1051 }
1052
1053 const char* session_reused = SSL_session_reused(impl->ssl) ? "true" : "false";
1054 result = tsi_construct_string_peer_property_from_cstring(
1055 TSI_SSL_SESSION_REUSED_PEER_PROPERTY, session_reused,
1056 &peer->properties[peer->property_count]);
1057 if (result != TSI_OK) return result;
1058 peer->property_count++;
1059
1060 return result;
1061 }
1062
ssl_handshaker_result_create_frame_protector(const tsi_handshaker_result * self,size_t * max_output_protected_frame_size,tsi_frame_protector ** protector)1063 static tsi_result ssl_handshaker_result_create_frame_protector(
1064 const tsi_handshaker_result* self, size_t* max_output_protected_frame_size,
1065 tsi_frame_protector** protector) {
1066 size_t actual_max_output_protected_frame_size =
1067 TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND;
1068 tsi_ssl_handshaker_result* impl =
1069 reinterpret_cast<tsi_ssl_handshaker_result*>(
1070 const_cast<tsi_handshaker_result*>(self));
1071 tsi_ssl_frame_protector* protector_impl =
1072 static_cast<tsi_ssl_frame_protector*>(
1073 gpr_zalloc(sizeof(*protector_impl)));
1074
1075 if (max_output_protected_frame_size != nullptr) {
1076 if (*max_output_protected_frame_size >
1077 TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND) {
1078 *max_output_protected_frame_size =
1079 TSI_SSL_MAX_PROTECTED_FRAME_SIZE_UPPER_BOUND;
1080 } else if (*max_output_protected_frame_size <
1081 TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND) {
1082 *max_output_protected_frame_size =
1083 TSI_SSL_MAX_PROTECTED_FRAME_SIZE_LOWER_BOUND;
1084 }
1085 actual_max_output_protected_frame_size = *max_output_protected_frame_size;
1086 }
1087 protector_impl->buffer_size =
1088 actual_max_output_protected_frame_size - TSI_SSL_MAX_PROTECTION_OVERHEAD;
1089 protector_impl->buffer =
1090 static_cast<unsigned char*>(gpr_malloc(protector_impl->buffer_size));
1091 if (protector_impl->buffer == nullptr) {
1092 gpr_log(GPR_ERROR,
1093 "Could not allocated buffer for tsi_ssl_frame_protector.");
1094 gpr_free(protector_impl);
1095 return TSI_INTERNAL_ERROR;
1096 }
1097
1098 /* Transfer ownership of ssl and network_io to the frame protector. */
1099 protector_impl->ssl = impl->ssl;
1100 impl->ssl = nullptr;
1101 protector_impl->network_io = impl->network_io;
1102 impl->network_io = nullptr;
1103 protector_impl->base.vtable = &frame_protector_vtable;
1104 *protector = &protector_impl->base;
1105 return TSI_OK;
1106 }
1107
ssl_handshaker_result_get_unused_bytes(const tsi_handshaker_result * self,const unsigned char ** bytes,size_t * bytes_size)1108 static tsi_result ssl_handshaker_result_get_unused_bytes(
1109 const tsi_handshaker_result* self, const unsigned char** bytes,
1110 size_t* bytes_size) {
1111 const tsi_ssl_handshaker_result* impl =
1112 reinterpret_cast<const tsi_ssl_handshaker_result*>(self);
1113 *bytes_size = impl->unused_bytes_size;
1114 *bytes = impl->unused_bytes;
1115 return TSI_OK;
1116 }
1117
ssl_handshaker_result_destroy(tsi_handshaker_result * self)1118 static void ssl_handshaker_result_destroy(tsi_handshaker_result* self) {
1119 tsi_ssl_handshaker_result* impl =
1120 reinterpret_cast<tsi_ssl_handshaker_result*>(self);
1121 SSL_free(impl->ssl);
1122 BIO_free(impl->network_io);
1123 gpr_free(impl->unused_bytes);
1124 gpr_free(impl);
1125 }
1126
1127 static const tsi_handshaker_result_vtable handshaker_result_vtable = {
1128 ssl_handshaker_result_extract_peer,
1129 nullptr, /* create_zero_copy_grpc_protector */
1130 ssl_handshaker_result_create_frame_protector,
1131 ssl_handshaker_result_get_unused_bytes,
1132 ssl_handshaker_result_destroy,
1133 };
1134
ssl_handshaker_result_create(tsi_ssl_handshaker * handshaker,const unsigned char * unused_bytes,size_t unused_bytes_size,tsi_handshaker_result ** handshaker_result)1135 static tsi_result ssl_handshaker_result_create(
1136 tsi_ssl_handshaker* handshaker, const unsigned char* unused_bytes,
1137 size_t unused_bytes_size, tsi_handshaker_result** handshaker_result) {
1138 if (handshaker == nullptr || handshaker_result == nullptr ||
1139 (unused_bytes_size > 0 && unused_bytes == nullptr)) {
1140 return TSI_INVALID_ARGUMENT;
1141 }
1142 tsi_ssl_handshaker_result* result =
1143 static_cast<tsi_ssl_handshaker_result*>(gpr_zalloc(sizeof(*result)));
1144 result->base.vtable = &handshaker_result_vtable;
1145 /* Transfer ownership of ssl and network_io to the handshaker result. */
1146 result->ssl = handshaker->ssl;
1147 handshaker->ssl = nullptr;
1148 result->network_io = handshaker->network_io;
1149 handshaker->network_io = nullptr;
1150 if (unused_bytes_size > 0) {
1151 result->unused_bytes =
1152 static_cast<unsigned char*>(gpr_malloc(unused_bytes_size));
1153 memcpy(result->unused_bytes, unused_bytes, unused_bytes_size);
1154 }
1155 result->unused_bytes_size = unused_bytes_size;
1156 *handshaker_result = &result->base;
1157 return TSI_OK;
1158 }
1159
1160 /* --- tsi_handshaker methods implementation. ---*/
1161
ssl_handshaker_get_bytes_to_send_to_peer(tsi_ssl_handshaker * impl,unsigned char * bytes,size_t * bytes_size)1162 static tsi_result ssl_handshaker_get_bytes_to_send_to_peer(
1163 tsi_ssl_handshaker* impl, unsigned char* bytes, size_t* bytes_size) {
1164 int bytes_read_from_ssl = 0;
1165 if (bytes == nullptr || bytes_size == nullptr || *bytes_size == 0 ||
1166 *bytes_size > INT_MAX) {
1167 return TSI_INVALID_ARGUMENT;
1168 }
1169 GPR_ASSERT(*bytes_size <= INT_MAX);
1170 bytes_read_from_ssl =
1171 BIO_read(impl->network_io, bytes, static_cast<int>(*bytes_size));
1172 if (bytes_read_from_ssl < 0) {
1173 *bytes_size = 0;
1174 if (!BIO_should_retry(impl->network_io)) {
1175 impl->result = TSI_INTERNAL_ERROR;
1176 return impl->result;
1177 } else {
1178 return TSI_OK;
1179 }
1180 }
1181 *bytes_size = static_cast<size_t>(bytes_read_from_ssl);
1182 return BIO_pending(impl->network_io) == 0 ? TSI_OK : TSI_INCOMPLETE_DATA;
1183 }
1184
ssl_handshaker_get_result(tsi_ssl_handshaker * impl)1185 static tsi_result ssl_handshaker_get_result(tsi_ssl_handshaker* impl) {
1186 if ((impl->result == TSI_HANDSHAKE_IN_PROGRESS) &&
1187 SSL_is_init_finished(impl->ssl)) {
1188 impl->result = TSI_OK;
1189 }
1190 return impl->result;
1191 }
1192
ssl_handshaker_process_bytes_from_peer(tsi_ssl_handshaker * impl,const unsigned char * bytes,size_t * bytes_size)1193 static tsi_result ssl_handshaker_process_bytes_from_peer(
1194 tsi_ssl_handshaker* impl, const unsigned char* bytes, size_t* bytes_size) {
1195 int bytes_written_into_ssl_size = 0;
1196 if (bytes == nullptr || bytes_size == nullptr || *bytes_size > INT_MAX) {
1197 return TSI_INVALID_ARGUMENT;
1198 }
1199 GPR_ASSERT(*bytes_size <= INT_MAX);
1200 bytes_written_into_ssl_size =
1201 BIO_write(impl->network_io, bytes, static_cast<int>(*bytes_size));
1202 if (bytes_written_into_ssl_size < 0) {
1203 gpr_log(GPR_ERROR, "Could not write to memory BIO.");
1204 impl->result = TSI_INTERNAL_ERROR;
1205 return impl->result;
1206 }
1207 *bytes_size = static_cast<size_t>(bytes_written_into_ssl_size);
1208
1209 if (ssl_handshaker_get_result(impl) != TSI_HANDSHAKE_IN_PROGRESS) {
1210 impl->result = TSI_OK;
1211 return impl->result;
1212 } else {
1213 /* Get ready to get some bytes from SSL. */
1214 int ssl_result = SSL_do_handshake(impl->ssl);
1215 ssl_result = SSL_get_error(impl->ssl, ssl_result);
1216 switch (ssl_result) {
1217 case SSL_ERROR_WANT_READ:
1218 if (BIO_pending(impl->network_io) == 0) {
1219 /* We need more data. */
1220 return TSI_INCOMPLETE_DATA;
1221 } else {
1222 return TSI_OK;
1223 }
1224 case SSL_ERROR_NONE:
1225 return TSI_OK;
1226 default: {
1227 char err_str[256];
1228 ERR_error_string_n(ERR_get_error(), err_str, sizeof(err_str));
1229 gpr_log(GPR_ERROR, "Handshake failed with fatal error %s: %s.",
1230 ssl_error_string(ssl_result), err_str);
1231 impl->result = TSI_PROTOCOL_FAILURE;
1232 return impl->result;
1233 }
1234 }
1235 }
1236 }
1237
ssl_handshaker_destroy(tsi_handshaker * self)1238 static void ssl_handshaker_destroy(tsi_handshaker* self) {
1239 tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
1240 SSL_free(impl->ssl);
1241 BIO_free(impl->network_io);
1242 gpr_free(impl->outgoing_bytes_buffer);
1243 tsi_ssl_handshaker_factory_unref(impl->factory_ref);
1244 gpr_free(impl);
1245 }
1246
ssl_handshaker_next(tsi_handshaker * self,const unsigned char * received_bytes,size_t received_bytes_size,const unsigned char ** bytes_to_send,size_t * bytes_to_send_size,tsi_handshaker_result ** handshaker_result,tsi_handshaker_on_next_done_cb cb,void * user_data)1247 static tsi_result ssl_handshaker_next(
1248 tsi_handshaker* self, const unsigned char* received_bytes,
1249 size_t received_bytes_size, const unsigned char** bytes_to_send,
1250 size_t* bytes_to_send_size, tsi_handshaker_result** handshaker_result,
1251 tsi_handshaker_on_next_done_cb cb, void* user_data) {
1252 /* Input sanity check. */
1253 if ((received_bytes_size > 0 && received_bytes == nullptr) ||
1254 bytes_to_send == nullptr || bytes_to_send_size == nullptr ||
1255 handshaker_result == nullptr) {
1256 return TSI_INVALID_ARGUMENT;
1257 }
1258 /* If there are received bytes, process them first. */
1259 tsi_ssl_handshaker* impl = reinterpret_cast<tsi_ssl_handshaker*>(self);
1260 tsi_result status = TSI_OK;
1261 size_t bytes_consumed = received_bytes_size;
1262 if (received_bytes_size > 0) {
1263 status = ssl_handshaker_process_bytes_from_peer(impl, received_bytes,
1264 &bytes_consumed);
1265 if (status != TSI_OK) return status;
1266 }
1267 /* Get bytes to send to the peer, if available. */
1268 size_t offset = 0;
1269 do {
1270 size_t to_send_size = impl->outgoing_bytes_buffer_size - offset;
1271 status = ssl_handshaker_get_bytes_to_send_to_peer(
1272 impl, impl->outgoing_bytes_buffer + offset, &to_send_size);
1273 offset += to_send_size;
1274 if (status == TSI_INCOMPLETE_DATA) {
1275 impl->outgoing_bytes_buffer_size *= 2;
1276 impl->outgoing_bytes_buffer = static_cast<unsigned char*>(gpr_realloc(
1277 impl->outgoing_bytes_buffer, impl->outgoing_bytes_buffer_size));
1278 }
1279 } while (status == TSI_INCOMPLETE_DATA);
1280 if (status != TSI_OK) return status;
1281 *bytes_to_send = impl->outgoing_bytes_buffer;
1282 *bytes_to_send_size = offset;
1283 /* If handshake completes, create tsi_handshaker_result. */
1284 if (ssl_handshaker_get_result(impl) == TSI_HANDSHAKE_IN_PROGRESS) {
1285 *handshaker_result = nullptr;
1286 } else {
1287 size_t unused_bytes_size = received_bytes_size - bytes_consumed;
1288 const unsigned char* unused_bytes =
1289 unused_bytes_size == 0 ? nullptr : received_bytes + bytes_consumed;
1290 status = ssl_handshaker_result_create(impl, unused_bytes, unused_bytes_size,
1291 handshaker_result);
1292 if (status == TSI_OK) {
1293 /* Indicates that the handshake has completed and that a handshaker_result
1294 * has been created. */
1295 self->handshaker_result_created = true;
1296 }
1297 }
1298 return status;
1299 }
1300
1301 static const tsi_handshaker_vtable handshaker_vtable = {
1302 nullptr, /* get_bytes_to_send_to_peer -- deprecated */
1303 nullptr, /* process_bytes_from_peer -- deprecated */
1304 nullptr, /* get_result -- deprecated */
1305 nullptr, /* extract_peer -- deprecated */
1306 nullptr, /* create_frame_protector -- deprecated */
1307 ssl_handshaker_destroy,
1308 ssl_handshaker_next,
1309 nullptr, /* shutdown */
1310 };
1311
1312 /* --- tsi_ssl_handshaker_factory common methods. --- */
1313
tsi_ssl_handshaker_resume_session(SSL * ssl,tsi::SslSessionLRUCache * session_cache)1314 static void tsi_ssl_handshaker_resume_session(
1315 SSL* ssl, tsi::SslSessionLRUCache* session_cache) {
1316 const char* server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1317 if (server_name == nullptr) {
1318 return;
1319 }
1320 tsi::SslSessionPtr session = session_cache->Get(server_name);
1321 if (session != nullptr) {
1322 // SSL_set_session internally increments reference counter.
1323 SSL_set_session(ssl, session.get());
1324 }
1325 }
1326
create_tsi_ssl_handshaker(SSL_CTX * ctx,int is_client,const char * server_name_indication,tsi_ssl_handshaker_factory * factory,tsi_handshaker ** handshaker)1327 static tsi_result create_tsi_ssl_handshaker(SSL_CTX* ctx, int is_client,
1328 const char* server_name_indication,
1329 tsi_ssl_handshaker_factory* factory,
1330 tsi_handshaker** handshaker) {
1331 SSL* ssl = SSL_new(ctx);
1332 BIO* network_io = nullptr;
1333 BIO* ssl_io = nullptr;
1334 tsi_ssl_handshaker* impl = nullptr;
1335 *handshaker = nullptr;
1336 if (ctx == nullptr) {
1337 gpr_log(GPR_ERROR, "SSL Context is null. Should never happen.");
1338 return TSI_INTERNAL_ERROR;
1339 }
1340 if (ssl == nullptr) {
1341 return TSI_OUT_OF_RESOURCES;
1342 }
1343 SSL_set_info_callback(ssl, ssl_info_callback);
1344
1345 if (!BIO_new_bio_pair(&network_io, 0, &ssl_io, 0)) {
1346 gpr_log(GPR_ERROR, "BIO_new_bio_pair failed.");
1347 SSL_free(ssl);
1348 return TSI_OUT_OF_RESOURCES;
1349 }
1350 SSL_set_bio(ssl, ssl_io, ssl_io);
1351 if (is_client) {
1352 int ssl_result;
1353 SSL_set_connect_state(ssl);
1354 if (server_name_indication != nullptr) {
1355 if (!SSL_set_tlsext_host_name(ssl, server_name_indication)) {
1356 gpr_log(GPR_ERROR, "Invalid server name indication %s.",
1357 server_name_indication);
1358 SSL_free(ssl);
1359 BIO_free(network_io);
1360 return TSI_INTERNAL_ERROR;
1361 }
1362 }
1363 tsi_ssl_client_handshaker_factory* client_factory =
1364 reinterpret_cast<tsi_ssl_client_handshaker_factory*>(factory);
1365 if (client_factory->session_cache != nullptr) {
1366 tsi_ssl_handshaker_resume_session(ssl,
1367 client_factory->session_cache.get());
1368 }
1369 ssl_result = SSL_do_handshake(ssl);
1370 ssl_result = SSL_get_error(ssl, ssl_result);
1371 if (ssl_result != SSL_ERROR_WANT_READ) {
1372 gpr_log(GPR_ERROR,
1373 "Unexpected error received from first SSL_do_handshake call: %s",
1374 ssl_error_string(ssl_result));
1375 SSL_free(ssl);
1376 BIO_free(network_io);
1377 return TSI_INTERNAL_ERROR;
1378 }
1379 } else {
1380 SSL_set_accept_state(ssl);
1381 }
1382
1383 impl = static_cast<tsi_ssl_handshaker*>(gpr_zalloc(sizeof(*impl)));
1384 impl->ssl = ssl;
1385 impl->network_io = network_io;
1386 impl->result = TSI_HANDSHAKE_IN_PROGRESS;
1387 impl->outgoing_bytes_buffer_size =
1388 TSI_SSL_HANDSHAKER_OUTGOING_BUFFER_INITIAL_SIZE;
1389 impl->outgoing_bytes_buffer =
1390 static_cast<unsigned char*>(gpr_zalloc(impl->outgoing_bytes_buffer_size));
1391 impl->base.vtable = &handshaker_vtable;
1392 impl->factory_ref = tsi_ssl_handshaker_factory_ref(factory);
1393
1394 *handshaker = &impl->base;
1395 return TSI_OK;
1396 }
1397
select_protocol_list(const unsigned char ** out,unsigned char * outlen,const unsigned char * client_list,size_t client_list_len,const unsigned char * server_list,size_t server_list_len)1398 static int select_protocol_list(const unsigned char** out,
1399 unsigned char* outlen,
1400 const unsigned char* client_list,
1401 size_t client_list_len,
1402 const unsigned char* server_list,
1403 size_t server_list_len) {
1404 const unsigned char* client_current = client_list;
1405 while (static_cast<unsigned int>(client_current - client_list) <
1406 client_list_len) {
1407 unsigned char client_current_len = *(client_current++);
1408 const unsigned char* server_current = server_list;
1409 while ((server_current >= server_list) &&
1410 static_cast<uintptr_t>(server_current - server_list) <
1411 server_list_len) {
1412 unsigned char server_current_len = *(server_current++);
1413 if ((client_current_len == server_current_len) &&
1414 !memcmp(client_current, server_current, server_current_len)) {
1415 *out = server_current;
1416 *outlen = server_current_len;
1417 return SSL_TLSEXT_ERR_OK;
1418 }
1419 server_current += server_current_len;
1420 }
1421 client_current += client_current_len;
1422 }
1423 return SSL_TLSEXT_ERR_NOACK;
1424 }
1425
1426 /* --- tsi_ssl_client_handshaker_factory methods implementation. --- */
1427
tsi_ssl_client_handshaker_factory_create_handshaker(tsi_ssl_client_handshaker_factory * self,const char * server_name_indication,tsi_handshaker ** handshaker)1428 tsi_result tsi_ssl_client_handshaker_factory_create_handshaker(
1429 tsi_ssl_client_handshaker_factory* self, const char* server_name_indication,
1430 tsi_handshaker** handshaker) {
1431 return create_tsi_ssl_handshaker(self->ssl_context, 1, server_name_indication,
1432 &self->base, handshaker);
1433 }
1434
tsi_ssl_client_handshaker_factory_unref(tsi_ssl_client_handshaker_factory * self)1435 void tsi_ssl_client_handshaker_factory_unref(
1436 tsi_ssl_client_handshaker_factory* self) {
1437 if (self == nullptr) return;
1438 tsi_ssl_handshaker_factory_unref(&self->base);
1439 }
1440
tsi_ssl_client_handshaker_factory_destroy(tsi_ssl_handshaker_factory * factory)1441 static void tsi_ssl_client_handshaker_factory_destroy(
1442 tsi_ssl_handshaker_factory* factory) {
1443 if (factory == nullptr) return;
1444 tsi_ssl_client_handshaker_factory* self =
1445 reinterpret_cast<tsi_ssl_client_handshaker_factory*>(factory);
1446 if (self->ssl_context != nullptr) SSL_CTX_free(self->ssl_context);
1447 if (self->alpn_protocol_list != nullptr) gpr_free(self->alpn_protocol_list);
1448 self->session_cache.reset();
1449 gpr_free(self);
1450 }
1451
client_handshaker_factory_npn_callback(SSL * ssl,unsigned char ** out,unsigned char * outlen,const unsigned char * in,unsigned int inlen,void * arg)1452 static int client_handshaker_factory_npn_callback(SSL* ssl, unsigned char** out,
1453 unsigned char* outlen,
1454 const unsigned char* in,
1455 unsigned int inlen,
1456 void* arg) {
1457 tsi_ssl_client_handshaker_factory* factory =
1458 static_cast<tsi_ssl_client_handshaker_factory*>(arg);
1459 return select_protocol_list((const unsigned char**)out, outlen,
1460 factory->alpn_protocol_list,
1461 factory->alpn_protocol_list_length, in, inlen);
1462 }
1463
1464 /* --- tsi_ssl_server_handshaker_factory methods implementation. --- */
1465
tsi_ssl_server_handshaker_factory_create_handshaker(tsi_ssl_server_handshaker_factory * self,tsi_handshaker ** handshaker)1466 tsi_result tsi_ssl_server_handshaker_factory_create_handshaker(
1467 tsi_ssl_server_handshaker_factory* self, tsi_handshaker** handshaker) {
1468 if (self->ssl_context_count == 0) return TSI_INVALID_ARGUMENT;
1469 /* Create the handshaker with the first context. We will switch if needed
1470 because of SNI in ssl_server_handshaker_factory_servername_callback. */
1471 return create_tsi_ssl_handshaker(self->ssl_contexts[0], 0, nullptr,
1472 &self->base, handshaker);
1473 }
1474
tsi_ssl_server_handshaker_factory_unref(tsi_ssl_server_handshaker_factory * self)1475 void tsi_ssl_server_handshaker_factory_unref(
1476 tsi_ssl_server_handshaker_factory* self) {
1477 if (self == nullptr) return;
1478 tsi_ssl_handshaker_factory_unref(&self->base);
1479 }
1480
tsi_ssl_server_handshaker_factory_destroy(tsi_ssl_handshaker_factory * factory)1481 static void tsi_ssl_server_handshaker_factory_destroy(
1482 tsi_ssl_handshaker_factory* factory) {
1483 if (factory == nullptr) return;
1484 tsi_ssl_server_handshaker_factory* self =
1485 reinterpret_cast<tsi_ssl_server_handshaker_factory*>(factory);
1486 size_t i;
1487 for (i = 0; i < self->ssl_context_count; i++) {
1488 if (self->ssl_contexts[i] != nullptr) {
1489 SSL_CTX_free(self->ssl_contexts[i]);
1490 tsi_peer_destruct(&self->ssl_context_x509_subject_names[i]);
1491 }
1492 }
1493 if (self->ssl_contexts != nullptr) gpr_free(self->ssl_contexts);
1494 if (self->ssl_context_x509_subject_names != nullptr) {
1495 gpr_free(self->ssl_context_x509_subject_names);
1496 }
1497 if (self->alpn_protocol_list != nullptr) gpr_free(self->alpn_protocol_list);
1498 gpr_free(self);
1499 }
1500
does_entry_match_name(const char * entry,size_t entry_length,const char * name)1501 static int does_entry_match_name(const char* entry, size_t entry_length,
1502 const char* name) {
1503 const char* dot;
1504 const char* name_subdomain = nullptr;
1505 size_t name_length = strlen(name);
1506 size_t name_subdomain_length;
1507 if (entry_length == 0) return 0;
1508
1509 /* Take care of '.' terminations. */
1510 if (name[name_length - 1] == '.') {
1511 name_length--;
1512 }
1513 if (entry[entry_length - 1] == '.') {
1514 entry_length--;
1515 if (entry_length == 0) return 0;
1516 }
1517
1518 if ((name_length == entry_length) &&
1519 strncmp(name, entry, entry_length) == 0) {
1520 return 1; /* Perfect match. */
1521 }
1522 if (entry[0] != '*') return 0;
1523
1524 /* Wildchar subdomain matching. */
1525 if (entry_length < 3 || entry[1] != '.') { /* At least *.x */
1526 gpr_log(GPR_ERROR, "Invalid wildchar entry.");
1527 return 0;
1528 }
1529 name_subdomain = strchr(name, '.');
1530 if (name_subdomain == nullptr) return 0;
1531 name_subdomain_length = strlen(name_subdomain);
1532 if (name_subdomain_length < 2) return 0;
1533 name_subdomain++; /* Starts after the dot. */
1534 name_subdomain_length--;
1535 entry += 2; /* Remove *. */
1536 entry_length -= 2;
1537 dot = strchr(name_subdomain, '.');
1538 if ((dot == nullptr) || (dot == &name_subdomain[name_subdomain_length - 1])) {
1539 gpr_log(GPR_ERROR, "Invalid toplevel subdomain: %s", name_subdomain);
1540 return 0;
1541 }
1542 if (name_subdomain[name_subdomain_length - 1] == '.') {
1543 name_subdomain_length--;
1544 }
1545 return ((entry_length > 0) && (name_subdomain_length == entry_length) &&
1546 strncmp(entry, name_subdomain, entry_length) == 0);
1547 }
1548
ssl_server_handshaker_factory_servername_callback(SSL * ssl,int * ap,void * arg)1549 static int ssl_server_handshaker_factory_servername_callback(SSL* ssl, int* ap,
1550 void* arg) {
1551 tsi_ssl_server_handshaker_factory* impl =
1552 static_cast<tsi_ssl_server_handshaker_factory*>(arg);
1553 size_t i = 0;
1554 const char* servername = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1555 if (servername == nullptr || strlen(servername) == 0) {
1556 return SSL_TLSEXT_ERR_NOACK;
1557 }
1558
1559 for (i = 0; i < impl->ssl_context_count; i++) {
1560 if (tsi_ssl_peer_matches_name(&impl->ssl_context_x509_subject_names[i],
1561 servername)) {
1562 SSL_set_SSL_CTX(ssl, impl->ssl_contexts[i]);
1563 return SSL_TLSEXT_ERR_OK;
1564 }
1565 }
1566 gpr_log(GPR_ERROR, "No match found for server name: %s.", servername);
1567 return SSL_TLSEXT_ERR_ALERT_WARNING;
1568 }
1569
#if TSI_OPENSSL_ALPN_SUPPORT
/* Server-side ALPN selection callback. |arg| is the server factory and |ssl|
   is unused. The client's offered list |in| is walked in client order
   against the factory's protocol list; on a match, |*out| points into the
   factory's list. */
static int server_handshaker_factory_alpn_callback(
    SSL* ssl, const unsigned char** out, unsigned char* outlen,
    const unsigned char* in, unsigned int inlen, void* arg) {
  const tsi_ssl_server_handshaker_factory* factory =
      static_cast<const tsi_ssl_server_handshaker_factory*>(arg);
  const unsigned char* server_protocols = factory->alpn_protocol_list;
  const size_t server_protocols_len = factory->alpn_protocol_list_length;
  return select_protocol_list(out, outlen, in, inlen, server_protocols,
                              server_protocols_len);
}
#endif /* TSI_OPENSSL_ALPN_SUPPORT */
1581
server_handshaker_factory_npn_advertised_callback(SSL * ssl,const unsigned char ** out,unsigned int * outlen,void * arg)1582 static int server_handshaker_factory_npn_advertised_callback(
1583 SSL* ssl, const unsigned char** out, unsigned int* outlen, void* arg) {
1584 tsi_ssl_server_handshaker_factory* factory =
1585 static_cast<tsi_ssl_server_handshaker_factory*>(arg);
1586 *out = factory->alpn_protocol_list;
1587 GPR_ASSERT(factory->alpn_protocol_list_length <= UINT_MAX);
1588 *outlen = static_cast<unsigned int>(factory->alpn_protocol_list_length);
1589 return SSL_TLSEXT_ERR_OK;
1590 }
1591
1592 /// This callback is called when new \a session is established and ready to
1593 /// be cached. This session can be reused for new connections to similar
1594 /// servers at later point of time.
1595 /// It's intended to be used with SSL_CTX_sess_set_new_cb function.
1596 ///
1597 /// It returns 1 if callback takes ownership over \a session and 0 otherwise.
server_handshaker_factory_new_session_callback(SSL * ssl,SSL_SESSION * session)1598 static int server_handshaker_factory_new_session_callback(
1599 SSL* ssl, SSL_SESSION* session) {
1600 SSL_CTX* ssl_context = SSL_get_SSL_CTX(ssl);
1601 if (ssl_context == nullptr) {
1602 return 0;
1603 }
1604 void* arg = SSL_CTX_get_ex_data(ssl_context, g_ssl_ctx_ex_factory_index);
1605 tsi_ssl_client_handshaker_factory* factory =
1606 static_cast<tsi_ssl_client_handshaker_factory*>(arg);
1607 const char* server_name = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1608 if (server_name == nullptr) {
1609 return 0;
1610 }
1611 factory->session_cache->Put(server_name, tsi::SslSessionPtr(session));
1612 // Return 1 to indicate transfered ownership over the given session.
1613 return 1;
1614 }
1615
1616 /* --- tsi_ssl_handshaker_factory constructors. --- */
1617
1618 static tsi_ssl_handshaker_factory_vtable client_handshaker_factory_vtable = {
1619 tsi_ssl_client_handshaker_factory_destroy};
1620
tsi_create_ssl_client_handshaker_factory(const tsi_ssl_pem_key_cert_pair * pem_key_cert_pair,const char * pem_root_certs,const char * cipher_suites,const char ** alpn_protocols,uint16_t num_alpn_protocols,tsi_ssl_client_handshaker_factory ** factory)1621 tsi_result tsi_create_ssl_client_handshaker_factory(
1622 const tsi_ssl_pem_key_cert_pair* pem_key_cert_pair,
1623 const char* pem_root_certs, const char* cipher_suites,
1624 const char** alpn_protocols, uint16_t num_alpn_protocols,
1625 tsi_ssl_client_handshaker_factory** factory) {
1626 tsi_ssl_client_handshaker_options options;
1627 memset(&options, 0, sizeof(options));
1628 options.pem_key_cert_pair = pem_key_cert_pair;
1629 options.pem_root_certs = pem_root_certs;
1630 options.cipher_suites = cipher_suites;
1631 options.alpn_protocols = alpn_protocols;
1632 options.num_alpn_protocols = num_alpn_protocols;
1633 return tsi_create_ssl_client_handshaker_factory_with_options(&options,
1634 factory);
1635 }
1636
tsi_create_ssl_client_handshaker_factory_with_options(const tsi_ssl_client_handshaker_options * options,tsi_ssl_client_handshaker_factory ** factory)1637 tsi_result tsi_create_ssl_client_handshaker_factory_with_options(
1638 const tsi_ssl_client_handshaker_options* options,
1639 tsi_ssl_client_handshaker_factory** factory) {
1640 SSL_CTX* ssl_context = nullptr;
1641 tsi_ssl_client_handshaker_factory* impl = nullptr;
1642 tsi_result result = TSI_OK;
1643
1644 gpr_once_init(&g_init_openssl_once, init_openssl);
1645
1646 if (factory == nullptr) return TSI_INVALID_ARGUMENT;
1647 *factory = nullptr;
1648 if (options->pem_root_certs == nullptr && options->root_store == nullptr) {
1649 return TSI_INVALID_ARGUMENT;
1650 }
1651
1652 ssl_context = SSL_CTX_new(TLSv1_2_method());
1653 if (ssl_context == nullptr) {
1654 gpr_log(GPR_ERROR, "Could not create ssl context.");
1655 return TSI_INVALID_ARGUMENT;
1656 }
1657
1658 impl = static_cast<tsi_ssl_client_handshaker_factory*>(
1659 gpr_zalloc(sizeof(*impl)));
1660 tsi_ssl_handshaker_factory_init(&impl->base);
1661 impl->base.vtable = &client_handshaker_factory_vtable;
1662 impl->ssl_context = ssl_context;
1663 if (options->session_cache != nullptr) {
1664 // Unref is called manually on factory destruction.
1665 impl->session_cache =
1666 reinterpret_cast<tsi::SslSessionLRUCache*>(options->session_cache)
1667 ->Ref();
1668 SSL_CTX_set_ex_data(ssl_context, g_ssl_ctx_ex_factory_index, impl);
1669 SSL_CTX_sess_set_new_cb(ssl_context,
1670 server_handshaker_factory_new_session_callback);
1671 SSL_CTX_set_session_cache_mode(ssl_context, SSL_SESS_CACHE_CLIENT);
1672 }
1673
1674 do {
1675 result = populate_ssl_context(ssl_context, options->pem_key_cert_pair,
1676 options->cipher_suites);
1677 if (result != TSI_OK) break;
1678
1679 #if OPENSSL_VERSION_NUMBER >= 0x10100000
1680 // X509_STORE_up_ref is only available since OpenSSL 1.1.
1681 if (options->root_store != nullptr) {
1682 X509_STORE_up_ref(options->root_store->store);
1683 SSL_CTX_set_cert_store(ssl_context, options->root_store->store);
1684 }
1685 #endif
1686 if (OPENSSL_VERSION_NUMBER < 0x10100000 || options->root_store == nullptr) {
1687 result = ssl_ctx_load_verification_certs(
1688 ssl_context, options->pem_root_certs, strlen(options->pem_root_certs),
1689 nullptr);
1690 if (result != TSI_OK) {
1691 gpr_log(GPR_ERROR, "Cannot load server root certificates.");
1692 break;
1693 }
1694 }
1695
1696 if (options->num_alpn_protocols != 0) {
1697 result = build_alpn_protocol_name_list(
1698 options->alpn_protocols, options->num_alpn_protocols,
1699 &impl->alpn_protocol_list, &impl->alpn_protocol_list_length);
1700 if (result != TSI_OK) {
1701 gpr_log(GPR_ERROR, "Building alpn list failed with error %s.",
1702 tsi_result_to_string(result));
1703 break;
1704 }
1705 #if TSI_OPENSSL_ALPN_SUPPORT
1706 GPR_ASSERT(impl->alpn_protocol_list_length < UINT_MAX);
1707 if (SSL_CTX_set_alpn_protos(
1708 ssl_context, impl->alpn_protocol_list,
1709 static_cast<unsigned int>(impl->alpn_protocol_list_length))) {
1710 gpr_log(GPR_ERROR, "Could not set alpn protocol list to context.");
1711 result = TSI_INVALID_ARGUMENT;
1712 break;
1713 }
1714 #endif /* TSI_OPENSSL_ALPN_SUPPORT */
1715 SSL_CTX_set_next_proto_select_cb(
1716 ssl_context, client_handshaker_factory_npn_callback, impl);
1717 }
1718 } while (0);
1719 if (result != TSI_OK) {
1720 tsi_ssl_handshaker_factory_unref(&impl->base);
1721 return result;
1722 }
1723 SSL_CTX_set_verify(ssl_context, SSL_VERIFY_PEER, nullptr);
1724 /* TODO(jboeuf): Add revocation verification. */
1725
1726 *factory = impl;
1727 return TSI_OK;
1728 }
1729
1730 static tsi_ssl_handshaker_factory_vtable server_handshaker_factory_vtable = {
1731 tsi_ssl_server_handshaker_factory_destroy};
1732
tsi_create_ssl_server_handshaker_factory(const tsi_ssl_pem_key_cert_pair * pem_key_cert_pairs,size_t num_key_cert_pairs,const char * pem_client_root_certs,int force_client_auth,const char * cipher_suites,const char ** alpn_protocols,uint16_t num_alpn_protocols,tsi_ssl_server_handshaker_factory ** factory)1733 tsi_result tsi_create_ssl_server_handshaker_factory(
1734 const tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs,
1735 size_t num_key_cert_pairs, const char* pem_client_root_certs,
1736 int force_client_auth, const char* cipher_suites,
1737 const char** alpn_protocols, uint16_t num_alpn_protocols,
1738 tsi_ssl_server_handshaker_factory** factory) {
1739 return tsi_create_ssl_server_handshaker_factory_ex(
1740 pem_key_cert_pairs, num_key_cert_pairs, pem_client_root_certs,
1741 force_client_auth ? TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY
1742 : TSI_DONT_REQUEST_CLIENT_CERTIFICATE,
1743 cipher_suites, alpn_protocols, num_alpn_protocols, factory);
1744 }
1745
tsi_create_ssl_server_handshaker_factory_ex(const tsi_ssl_pem_key_cert_pair * pem_key_cert_pairs,size_t num_key_cert_pairs,const char * pem_client_root_certs,tsi_client_certificate_request_type client_certificate_request,const char * cipher_suites,const char ** alpn_protocols,uint16_t num_alpn_protocols,tsi_ssl_server_handshaker_factory ** factory)1746 tsi_result tsi_create_ssl_server_handshaker_factory_ex(
1747 const tsi_ssl_pem_key_cert_pair* pem_key_cert_pairs,
1748 size_t num_key_cert_pairs, const char* pem_client_root_certs,
1749 tsi_client_certificate_request_type client_certificate_request,
1750 const char* cipher_suites, const char** alpn_protocols,
1751 uint16_t num_alpn_protocols, tsi_ssl_server_handshaker_factory** factory) {
1752 tsi_ssl_server_handshaker_options options;
1753 memset(&options, 0, sizeof(options));
1754 options.pem_key_cert_pairs = pem_key_cert_pairs;
1755 options.num_key_cert_pairs = num_key_cert_pairs;
1756 options.pem_client_root_certs = pem_client_root_certs;
1757 options.client_certificate_request = client_certificate_request;
1758 options.cipher_suites = cipher_suites;
1759 options.alpn_protocols = alpn_protocols;
1760 options.num_alpn_protocols = num_alpn_protocols;
1761 return tsi_create_ssl_server_handshaker_factory_with_options(&options,
1762 factory);
1763 }
1764
tsi_create_ssl_server_handshaker_factory_with_options(const tsi_ssl_server_handshaker_options * options,tsi_ssl_server_handshaker_factory ** factory)1765 tsi_result tsi_create_ssl_server_handshaker_factory_with_options(
1766 const tsi_ssl_server_handshaker_options* options,
1767 tsi_ssl_server_handshaker_factory** factory) {
1768 tsi_ssl_server_handshaker_factory* impl = nullptr;
1769 tsi_result result = TSI_OK;
1770 size_t i = 0;
1771
1772 gpr_once_init(&g_init_openssl_once, init_openssl);
1773
1774 if (factory == nullptr) return TSI_INVALID_ARGUMENT;
1775 *factory = nullptr;
1776 if (options->num_key_cert_pairs == 0 ||
1777 options->pem_key_cert_pairs == nullptr) {
1778 return TSI_INVALID_ARGUMENT;
1779 }
1780
1781 impl = static_cast<tsi_ssl_server_handshaker_factory*>(
1782 gpr_zalloc(sizeof(*impl)));
1783 tsi_ssl_handshaker_factory_init(&impl->base);
1784 impl->base.vtable = &server_handshaker_factory_vtable;
1785
1786 impl->ssl_contexts = static_cast<SSL_CTX**>(
1787 gpr_zalloc(options->num_key_cert_pairs * sizeof(SSL_CTX*)));
1788 impl->ssl_context_x509_subject_names = static_cast<tsi_peer*>(
1789 gpr_zalloc(options->num_key_cert_pairs * sizeof(tsi_peer)));
1790 if (impl->ssl_contexts == nullptr ||
1791 impl->ssl_context_x509_subject_names == nullptr) {
1792 tsi_ssl_handshaker_factory_unref(&impl->base);
1793 return TSI_OUT_OF_RESOURCES;
1794 }
1795 impl->ssl_context_count = options->num_key_cert_pairs;
1796
1797 if (options->num_alpn_protocols > 0) {
1798 result = build_alpn_protocol_name_list(
1799 options->alpn_protocols, options->num_alpn_protocols,
1800 &impl->alpn_protocol_list, &impl->alpn_protocol_list_length);
1801 if (result != TSI_OK) {
1802 tsi_ssl_handshaker_factory_unref(&impl->base);
1803 return result;
1804 }
1805 }
1806
1807 for (i = 0; i < options->num_key_cert_pairs; i++) {
1808 do {
1809 impl->ssl_contexts[i] = SSL_CTX_new(TLSv1_2_method());
1810 if (impl->ssl_contexts[i] == nullptr) {
1811 gpr_log(GPR_ERROR, "Could not create ssl context.");
1812 result = TSI_OUT_OF_RESOURCES;
1813 break;
1814 }
1815 result = populate_ssl_context(impl->ssl_contexts[i],
1816 &options->pem_key_cert_pairs[i],
1817 options->cipher_suites);
1818 if (result != TSI_OK) break;
1819
1820 // TODO(elessar): Provide ability to disable session ticket keys.
1821
1822 // Allow client cache sessions (it's needed for OpenSSL only).
1823 int set_sid_ctx_result = SSL_CTX_set_session_id_context(
1824 impl->ssl_contexts[i], kSslSessionIdContext,
1825 GPR_ARRAY_SIZE(kSslSessionIdContext));
1826 if (set_sid_ctx_result == 0) {
1827 gpr_log(GPR_ERROR, "Failed to set session id context.");
1828 result = TSI_INTERNAL_ERROR;
1829 break;
1830 }
1831
1832 if (options->session_ticket_key != nullptr) {
1833 if (SSL_CTX_set_tlsext_ticket_keys(
1834 impl->ssl_contexts[i],
1835 const_cast<char*>(options->session_ticket_key),
1836 options->session_ticket_key_size) == 0) {
1837 gpr_log(GPR_ERROR, "Invalid STEK size.");
1838 result = TSI_INVALID_ARGUMENT;
1839 break;
1840 }
1841 }
1842
1843 if (options->pem_client_root_certs != nullptr) {
1844 STACK_OF(X509_NAME)* root_names = nullptr;
1845 result = ssl_ctx_load_verification_certs(
1846 impl->ssl_contexts[i], options->pem_client_root_certs,
1847 strlen(options->pem_client_root_certs), &root_names);
1848 if (result != TSI_OK) {
1849 gpr_log(GPR_ERROR, "Invalid verification certs.");
1850 break;
1851 }
1852 SSL_CTX_set_client_CA_list(impl->ssl_contexts[i], root_names);
1853 switch (options->client_certificate_request) {
1854 case TSI_DONT_REQUEST_CLIENT_CERTIFICATE:
1855 SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_NONE, nullptr);
1856 break;
1857 case TSI_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
1858 SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER,
1859 NullVerifyCallback);
1860 break;
1861 case TSI_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY:
1862 SSL_CTX_set_verify(impl->ssl_contexts[i], SSL_VERIFY_PEER, nullptr);
1863 break;
1864 case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
1865 SSL_CTX_set_verify(
1866 impl->ssl_contexts[i],
1867 SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
1868 NullVerifyCallback);
1869 break;
1870 case TSI_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY:
1871 SSL_CTX_set_verify(
1872 impl->ssl_contexts[i],
1873 SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr);
1874 break;
1875 }
1876 /* TODO(jboeuf): Add revocation verification. */
1877 }
1878
1879 result = extract_x509_subject_names_from_pem_cert(
1880 options->pem_key_cert_pairs[i].cert_chain,
1881 &impl->ssl_context_x509_subject_names[i]);
1882 if (result != TSI_OK) break;
1883
1884 SSL_CTX_set_tlsext_servername_callback(
1885 impl->ssl_contexts[i],
1886 ssl_server_handshaker_factory_servername_callback);
1887 SSL_CTX_set_tlsext_servername_arg(impl->ssl_contexts[i], impl);
1888 #if TSI_OPENSSL_ALPN_SUPPORT
1889 SSL_CTX_set_alpn_select_cb(impl->ssl_contexts[i],
1890 server_handshaker_factory_alpn_callback, impl);
1891 #endif /* TSI_OPENSSL_ALPN_SUPPORT */
1892 SSL_CTX_set_next_protos_advertised_cb(
1893 impl->ssl_contexts[i],
1894 server_handshaker_factory_npn_advertised_callback, impl);
1895 } while (0);
1896
1897 if (result != TSI_OK) {
1898 tsi_ssl_handshaker_factory_unref(&impl->base);
1899 return result;
1900 }
1901 }
1902
1903 *factory = impl;
1904 return TSI_OK;
1905 }
1906
1907 /* --- tsi_ssl utils. --- */
1908
tsi_ssl_peer_matches_name(const tsi_peer * peer,const char * name)1909 int tsi_ssl_peer_matches_name(const tsi_peer* peer, const char* name) {
1910 size_t i = 0;
1911 size_t san_count = 0;
1912 const tsi_peer_property* cn_property = nullptr;
1913 int like_ip = looks_like_ip_address(name);
1914
1915 /* Check the SAN first. */
1916 for (i = 0; i < peer->property_count; i++) {
1917 const tsi_peer_property* property = &peer->properties[i];
1918 if (property->name == nullptr) continue;
1919 if (strcmp(property->name,
1920 TSI_X509_SUBJECT_ALTERNATIVE_NAME_PEER_PROPERTY) == 0) {
1921 san_count++;
1922
1923 if (!like_ip && does_entry_match_name(property->value.data,
1924 property->value.length, name)) {
1925 return 1;
1926 } else if (like_ip &&
1927 strncmp(name, property->value.data, property->value.length) ==
1928 0 &&
1929 strlen(name) == property->value.length) {
1930 /* IP Addresses are exact matches only. */
1931 return 1;
1932 }
1933 } else if (strcmp(property->name,
1934 TSI_X509_SUBJECT_COMMON_NAME_PEER_PROPERTY) == 0) {
1935 cn_property = property;
1936 }
1937 }
1938
1939 /* If there's no SAN, try the CN, but only if its not like an IP Address */
1940 if (san_count == 0 && cn_property != nullptr && !like_ip) {
1941 if (does_entry_match_name(cn_property->value.data,
1942 cn_property->value.length, name)) {
1943 return 1;
1944 }
1945 }
1946
1947 return 0; /* Not found. */
1948 }
1949
1950 /* --- Testing support. --- */
tsi_ssl_handshaker_factory_swap_vtable(tsi_ssl_handshaker_factory * factory,tsi_ssl_handshaker_factory_vtable * new_vtable)1951 const tsi_ssl_handshaker_factory_vtable* tsi_ssl_handshaker_factory_swap_vtable(
1952 tsi_ssl_handshaker_factory* factory,
1953 tsi_ssl_handshaker_factory_vtable* new_vtable) {
1954 GPR_ASSERT(factory != nullptr);
1955 GPR_ASSERT(factory->vtable != nullptr);
1956
1957 const tsi_ssl_handshaker_factory_vtable* orig_vtable = factory->vtable;
1958 factory->vtable = new_vtable;
1959 return orig_vtable;
1960 }
1961