• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 1995-2016 The OpenSSL Project Authors. All Rights Reserved.
3  * Copyright (c) 2002, Oracle and/or its affiliates. All rights reserved.
4  *
5  * Licensed under the OpenSSL license (the "License").  You may not use
6  * this file except in compliance with the License.  You can obtain a copy
7  * in the file LICENSE in the source distribution or at
8  * https://www.openssl.org/source/license.html
9  */
10 
11 #include <openssl/ssl.h>
12 
13 #include <assert.h>
14 #include <limits.h>
15 #include <string.h>
16 
17 #include <tuple>
18 
19 #include <openssl/buf.h>
20 #include <openssl/bytestring.h>
21 #include <openssl/err.h>
22 #include <openssl/evp.h>
23 #include <openssl/md5.h>
24 #include <openssl/mem.h>
25 #include <openssl/nid.h>
26 #include <openssl/rand.h>
27 #include <openssl/sha.h>
28 
29 #include "../crypto/internal.h"
30 #include "internal.h"
31 
32 
33 BSSL_NAMESPACE_BEGIN
34 
add_record_to_flight(SSL * ssl,uint8_t type,Span<const uint8_t> in)35 static bool add_record_to_flight(SSL *ssl, uint8_t type,
36                                  Span<const uint8_t> in) {
37   // The caller should have flushed |pending_hs_data| first.
38   assert(!ssl->s3->pending_hs_data);
39   // We'll never add a flight while in the process of writing it out.
40   assert(ssl->s3->pending_flight_offset == 0);
41 
42   if (ssl->s3->pending_flight == nullptr) {
43     ssl->s3->pending_flight.reset(BUF_MEM_new());
44     if (ssl->s3->pending_flight == nullptr) {
45       return false;
46     }
47   }
48 
49   size_t max_out = in.size() + SSL_max_seal_overhead(ssl);
50   size_t new_cap = ssl->s3->pending_flight->length + max_out;
51   if (max_out < in.size() || new_cap < max_out) {
52     OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
53     return false;
54   }
55 
56   size_t len;
57   if (!BUF_MEM_reserve(ssl->s3->pending_flight.get(), new_cap) ||
58       !tls_seal_record(ssl,
59                        (uint8_t *)ssl->s3->pending_flight->data +
60                            ssl->s3->pending_flight->length,
61                        &len, max_out, type, in.data(), in.size())) {
62     return false;
63   }
64 
65   ssl->s3->pending_flight->length += len;
66   return true;
67 }
68 
tls_init_message(const SSL * ssl,CBB * cbb,CBB * body,uint8_t type)69 bool tls_init_message(const SSL *ssl, CBB *cbb, CBB *body, uint8_t type) {
70   // Pick a modest size hint to save most of the |realloc| calls.
71   if (!CBB_init(cbb, 64) ||      //
72       !CBB_add_u8(cbb, type) ||  //
73       !CBB_add_u24_length_prefixed(cbb, body)) {
74     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
75     CBB_cleanup(cbb);
76     return false;
77   }
78 
79   return true;
80 }
81 
tls_finish_message(const SSL * ssl,CBB * cbb,Array<uint8_t> * out_msg)82 bool tls_finish_message(const SSL *ssl, CBB *cbb, Array<uint8_t> *out_msg) {
83   return CBBFinishArray(cbb, out_msg);
84 }
85 
tls_add_message(SSL * ssl,Array<uint8_t> msg)86 bool tls_add_message(SSL *ssl, Array<uint8_t> msg) {
87   // Pack handshake data into the minimal number of records. This avoids
88   // unnecessary encryption overhead, notably in TLS 1.3 where we send several
89   // encrypted messages in a row. For now, we do not do this for the null
90   // cipher. The benefit is smaller and there is a risk of breaking buggy
91   // implementations.
92   //
93   // TODO(crbug.com/374991962): See if we can do this uniformly.
94   Span<const uint8_t> rest = msg;
95   if (!SSL_is_quic(ssl) && ssl->s3->aead_write_ctx->is_null_cipher()) {
96     while (!rest.empty()) {
97       Span<const uint8_t> chunk = rest.subspan(0, ssl->max_send_fragment);
98       rest = rest.subspan(chunk.size());
99 
100       if (!add_record_to_flight(ssl, SSL3_RT_HANDSHAKE, chunk)) {
101         return false;
102       }
103     }
104   } else {
105     while (!rest.empty()) {
106       // Flush if |pending_hs_data| is full.
107       if (ssl->s3->pending_hs_data &&
108           ssl->s3->pending_hs_data->length >= ssl->max_send_fragment &&
109           !tls_flush_pending_hs_data(ssl)) {
110         return false;
111       }
112 
113       size_t pending_len =
114           ssl->s3->pending_hs_data ? ssl->s3->pending_hs_data->length : 0;
115       Span<const uint8_t> chunk =
116           rest.subspan(0, ssl->max_send_fragment - pending_len);
117       assert(!chunk.empty());
118       rest = rest.subspan(chunk.size());
119 
120       if (!ssl->s3->pending_hs_data) {
121         ssl->s3->pending_hs_data.reset(BUF_MEM_new());
122       }
123       if (!ssl->s3->pending_hs_data ||
124           !BUF_MEM_append(ssl->s3->pending_hs_data.get(), chunk.data(),
125                           chunk.size())) {
126         return false;
127       }
128     }
129   }
130 
131   ssl_do_msg_callback(ssl, 1 /* write */, SSL3_RT_HANDSHAKE, msg);
132   // TODO(svaldez): Move this up a layer to fix abstraction for SSLTranscript on
133   // hs.
134   if (ssl->s3->hs != NULL &&  //
135       !ssl->s3->hs->transcript.Update(msg)) {
136     return false;
137   }
138   return true;
139 }
140 
tls_flush_pending_hs_data(SSL * ssl)141 bool tls_flush_pending_hs_data(SSL *ssl) {
142   if (!ssl->s3->pending_hs_data || ssl->s3->pending_hs_data->length == 0) {
143     return true;
144   }
145 
146   UniquePtr<BUF_MEM> pending_hs_data = std::move(ssl->s3->pending_hs_data);
147   auto data = Span(reinterpret_cast<const uint8_t *>(pending_hs_data->data),
148                    pending_hs_data->length);
149   if (SSL_is_quic(ssl)) {
150     if ((ssl->s3->hs == nullptr || !ssl->s3->hs->hints_requested) &&
151         !ssl->quic_method->add_handshake_data(ssl, ssl->s3->quic_write_level,
152                                               data.data(), data.size())) {
153       OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
154       return false;
155     }
156     return true;
157   }
158 
159   return add_record_to_flight(ssl, SSL3_RT_HANDSHAKE, data);
160 }
161 
tls_add_change_cipher_spec(SSL * ssl)162 bool tls_add_change_cipher_spec(SSL *ssl) {
163   if (SSL_is_quic(ssl)) {
164     return true;
165   }
166 
167   static const uint8_t kChangeCipherSpec[1] = {SSL3_MT_CCS};
168   if (!tls_flush_pending_hs_data(ssl) ||
169       !add_record_to_flight(ssl, SSL3_RT_CHANGE_CIPHER_SPEC,
170                             kChangeCipherSpec)) {
171     return false;
172   }
173 
174   ssl_do_msg_callback(ssl, 1 /* write */, SSL3_RT_CHANGE_CIPHER_SPEC,
175                       kChangeCipherSpec);
176   return true;
177 }
178 
tls_flush(SSL * ssl)179 int tls_flush(SSL *ssl) {
180   if (!tls_flush_pending_hs_data(ssl)) {
181     return -1;
182   }
183 
184   if (SSL_is_quic(ssl)) {
185     if (ssl->s3->write_shutdown != ssl_shutdown_none) {
186       OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
187       return -1;
188     }
189 
190     if (!ssl->quic_method->flush_flight(ssl)) {
191       OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR);
192       return -1;
193     }
194   }
195 
196   if (ssl->s3->pending_flight == nullptr) {
197     return 1;
198   }
199 
200   if (ssl->s3->write_shutdown != ssl_shutdown_none) {
201     OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
202     return -1;
203   }
204 
205   static_assert(INT_MAX <= 0xffffffff, "int is larger than 32 bits");
206   if (ssl->s3->pending_flight->length > INT_MAX) {
207     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
208     return -1;
209   }
210 
211   // If there is pending data in the write buffer, it must be flushed out before
212   // any new data in pending_flight.
213   if (!ssl->s3->write_buffer.empty()) {
214     int ret = ssl_write_buffer_flush(ssl);
215     if (ret <= 0) {
216       ssl->s3->rwstate = SSL_ERROR_WANT_WRITE;
217       return ret;
218     }
219   }
220 
221   if (ssl->wbio == nullptr) {
222     OPENSSL_PUT_ERROR(SSL, SSL_R_BIO_NOT_SET);
223     return -1;
224   }
225 
226   // Write the pending flight.
227   while (ssl->s3->pending_flight_offset < ssl->s3->pending_flight->length) {
228     int ret = BIO_write(
229         ssl->wbio.get(),
230         ssl->s3->pending_flight->data + ssl->s3->pending_flight_offset,
231         ssl->s3->pending_flight->length - ssl->s3->pending_flight_offset);
232     if (ret <= 0) {
233       ssl->s3->rwstate = SSL_ERROR_WANT_WRITE;
234       return ret;
235     }
236 
237     ssl->s3->pending_flight_offset += ret;
238   }
239 
240   if (BIO_flush(ssl->wbio.get()) <= 0) {
241     ssl->s3->rwstate = SSL_ERROR_WANT_WRITE;
242     return -1;
243   }
244 
245   ssl->s3->pending_flight.reset();
246   ssl->s3->pending_flight_offset = 0;
247   return 1;
248 }
249 
read_v2_client_hello(SSL * ssl,size_t * out_consumed,Span<const uint8_t> in)250 static ssl_open_record_t read_v2_client_hello(SSL *ssl, size_t *out_consumed,
251                                               Span<const uint8_t> in) {
252   *out_consumed = 0;
253   assert(in.size() >= SSL3_RT_HEADER_LENGTH);
254   // Determine the length of the V2ClientHello.
255   size_t msg_length = ((in[0] & 0x7f) << 8) | in[1];
256   if (msg_length > (1024 * 4)) {
257     OPENSSL_PUT_ERROR(SSL, SSL_R_RECORD_TOO_LARGE);
258     return ssl_open_record_error;
259   }
260   if (msg_length < SSL3_RT_HEADER_LENGTH - 2) {
261     // Reject lengths that are too short early. We have already read
262     // |SSL3_RT_HEADER_LENGTH| bytes, so we should not attempt to process an
263     // (invalid) V2ClientHello which would be shorter than that.
264     OPENSSL_PUT_ERROR(SSL, SSL_R_RECORD_LENGTH_MISMATCH);
265     return ssl_open_record_error;
266   }
267 
268   // Ask for the remainder of the V2ClientHello.
269   if (in.size() < 2 + msg_length) {
270     *out_consumed = 2 + msg_length;
271     return ssl_open_record_partial;
272   }
273 
274   CBS v2_client_hello = CBS(in.subspan(2, msg_length));
275   // The V2ClientHello without the length is incorporated into the handshake
276   // hash. This is only ever called at the start of the handshake, so hs is
277   // guaranteed to be non-NULL.
278   if (!ssl->s3->hs->transcript.Update(v2_client_hello)) {
279     return ssl_open_record_error;
280   }
281 
282   ssl_do_msg_callback(ssl, 0 /* read */, 0 /* V2ClientHello */,
283                       v2_client_hello);
284 
285   uint8_t msg_type;
286   uint16_t version, cipher_spec_length, session_id_length, challenge_length;
287   CBS cipher_specs, session_id, challenge;
288   if (!CBS_get_u8(&v2_client_hello, &msg_type) ||
289       !CBS_get_u16(&v2_client_hello, &version) ||
290       !CBS_get_u16(&v2_client_hello, &cipher_spec_length) ||
291       !CBS_get_u16(&v2_client_hello, &session_id_length) ||
292       !CBS_get_u16(&v2_client_hello, &challenge_length) ||
293       !CBS_get_bytes(&v2_client_hello, &cipher_specs, cipher_spec_length) ||
294       !CBS_get_bytes(&v2_client_hello, &session_id, session_id_length) ||
295       !CBS_get_bytes(&v2_client_hello, &challenge, challenge_length) ||
296       CBS_len(&v2_client_hello) != 0) {
297     OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
298     return ssl_open_record_error;
299   }
300 
301   // msg_type has already been checked.
302   assert(msg_type == SSL2_MT_CLIENT_HELLO);
303 
304   // The client_random is the V2ClientHello challenge. Truncate or left-pad with
305   // zeros as needed.
306   size_t rand_len = CBS_len(&challenge);
307   if (rand_len > SSL3_RANDOM_SIZE) {
308     rand_len = SSL3_RANDOM_SIZE;
309   }
310   uint8_t random[SSL3_RANDOM_SIZE];
311   OPENSSL_memset(random, 0, SSL3_RANDOM_SIZE);
312   OPENSSL_memcpy(random + (SSL3_RANDOM_SIZE - rand_len), CBS_data(&challenge),
313                  rand_len);
314 
315   // Write out an equivalent TLS ClientHello directly to the handshake buffer.
316   size_t max_v3_client_hello = SSL3_HM_HEADER_LENGTH + 2 /* version */ +
317                                SSL3_RANDOM_SIZE + 1 /* session ID length */ +
318                                2 /* cipher list length */ +
319                                CBS_len(&cipher_specs) / 3 * 2 +
320                                1 /* compression length */ + 1 /* compression */;
321   ScopedCBB client_hello;
322   CBB hello_body, cipher_suites;
323   if (!ssl->s3->hs_buf) {
324     ssl->s3->hs_buf.reset(BUF_MEM_new());
325   }
326   if (!ssl->s3->hs_buf ||
327       !BUF_MEM_reserve(ssl->s3->hs_buf.get(), max_v3_client_hello) ||
328       !CBB_init_fixed(client_hello.get(), (uint8_t *)ssl->s3->hs_buf->data,
329                       ssl->s3->hs_buf->max) ||
330       !CBB_add_u8(client_hello.get(), SSL3_MT_CLIENT_HELLO) ||
331       !CBB_add_u24_length_prefixed(client_hello.get(), &hello_body) ||
332       !CBB_add_u16(&hello_body, version) ||
333       !CBB_add_bytes(&hello_body, random, SSL3_RANDOM_SIZE) ||
334       // No session id.
335       !CBB_add_u8(&hello_body, 0) ||
336       !CBB_add_u16_length_prefixed(&hello_body, &cipher_suites)) {
337     return ssl_open_record_error;
338   }
339 
340   // Copy the cipher suites.
341   while (CBS_len(&cipher_specs) > 0) {
342     uint32_t cipher_spec;
343     if (!CBS_get_u24(&cipher_specs, &cipher_spec)) {
344       OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
345       return ssl_open_record_error;
346     }
347 
348     // Skip SSLv2 ciphers.
349     if ((cipher_spec & 0xff0000) != 0) {
350       continue;
351     }
352     if (!CBB_add_u16(&cipher_suites, cipher_spec)) {
353       OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
354       return ssl_open_record_error;
355     }
356   }
357 
358   // Add the null compression scheme and finish.
359   if (!CBB_add_u8(&hello_body, 1) ||  //
360       !CBB_add_u8(&hello_body, 0) ||  //
361       !CBB_finish(client_hello.get(), NULL, &ssl->s3->hs_buf->length)) {
362     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
363     return ssl_open_record_error;
364   }
365 
366   *out_consumed = 2 + msg_length;
367   ssl->s3->is_v2_hello = true;
368   return ssl_open_record_success;
369 }
370 
parse_message(const SSL * ssl,SSLMessage * out,size_t * out_bytes_needed)371 static bool parse_message(const SSL *ssl, SSLMessage *out,
372                           size_t *out_bytes_needed) {
373   if (!ssl->s3->hs_buf) {
374     *out_bytes_needed = 4;
375     return false;
376   }
377 
378   CBS cbs;
379   uint32_t len;
380   CBS_init(&cbs, reinterpret_cast<const uint8_t *>(ssl->s3->hs_buf->data),
381            ssl->s3->hs_buf->length);
382   if (!CBS_get_u8(&cbs, &out->type) ||  //
383       !CBS_get_u24(&cbs, &len)) {
384     *out_bytes_needed = 4;
385     return false;
386   }
387 
388   if (!CBS_get_bytes(&cbs, &out->body, len)) {
389     *out_bytes_needed = 4 + len;
390     return false;
391   }
392 
393   CBS_init(&out->raw, reinterpret_cast<const uint8_t *>(ssl->s3->hs_buf->data),
394            4 + len);
395   out->is_v2_hello = ssl->s3->is_v2_hello;
396   return true;
397 }
398 
tls_get_message(const SSL * ssl,SSLMessage * out)399 bool tls_get_message(const SSL *ssl, SSLMessage *out) {
400   size_t unused;
401   if (!parse_message(ssl, out, &unused)) {
402     return false;
403   }
404   if (!ssl->s3->has_message) {
405     if (!out->is_v2_hello) {
406       ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HANDSHAKE, out->raw);
407     }
408     ssl->s3->has_message = true;
409   }
410   return true;
411 }
412 
tls_can_accept_handshake_data(const SSL * ssl,uint8_t * out_alert)413 bool tls_can_accept_handshake_data(const SSL *ssl, uint8_t *out_alert) {
414   // If there is a complete message, the caller must have consumed it first.
415   SSLMessage msg;
416   size_t bytes_needed;
417   if (parse_message(ssl, &msg, &bytes_needed)) {
418     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
419     *out_alert = SSL_AD_INTERNAL_ERROR;
420     return false;
421   }
422 
423   // Enforce the limit so the peer cannot force us to buffer 16MB.
424   if (bytes_needed > 4 + ssl_max_handshake_message_len(ssl)) {
425     OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
426     *out_alert = SSL_AD_ILLEGAL_PARAMETER;
427     return false;
428   }
429 
430   return true;
431 }
432 
tls_has_unprocessed_handshake_data(const SSL * ssl)433 bool tls_has_unprocessed_handshake_data(const SSL *ssl) {
434   size_t msg_len = 0;
435   if (ssl->s3->has_message) {
436     SSLMessage msg;
437     size_t unused;
438     if (parse_message(ssl, &msg, &unused)) {
439       msg_len = CBS_len(&msg.raw);
440     }
441   }
442 
443   return ssl->s3->hs_buf && ssl->s3->hs_buf->length > msg_len;
444 }
445 
tls_append_handshake_data(SSL * ssl,Span<const uint8_t> data)446 bool tls_append_handshake_data(SSL *ssl, Span<const uint8_t> data) {
447   // Re-create the handshake buffer if needed.
448   if (!ssl->s3->hs_buf) {
449     ssl->s3->hs_buf.reset(BUF_MEM_new());
450   }
451   return ssl->s3->hs_buf &&
452          BUF_MEM_append(ssl->s3->hs_buf.get(), data.data(), data.size());
453 }
454 
tls_open_handshake(SSL * ssl,size_t * out_consumed,uint8_t * out_alert,Span<uint8_t> in)455 ssl_open_record_t tls_open_handshake(SSL *ssl, size_t *out_consumed,
456                                      uint8_t *out_alert, Span<uint8_t> in) {
457   *out_consumed = 0;
458   // Bypass the record layer for the first message to handle V2ClientHello.
459   if (ssl->server && !ssl->s3->v2_hello_done) {
460     // Ask for the first 5 bytes, the size of the TLS record header. This is
461     // sufficient to detect a V2ClientHello and ensures that we never read
462     // beyond the first record.
463     if (in.size() < SSL3_RT_HEADER_LENGTH) {
464       *out_consumed = SSL3_RT_HEADER_LENGTH;
465       return ssl_open_record_partial;
466     }
467 
468     // Some dedicated error codes for protocol mixups should the application
469     // wish to interpret them differently. (These do not overlap with
470     // ClientHello or V2ClientHello.)
471     auto str = bssl::BytesAsStringView(in);
472     if (str.substr(0, 4) == "GET " ||   //
473         str.substr(0, 5) == "POST " ||  //
474         str.substr(0, 5) == "HEAD " ||  //
475         str.substr(0, 4) == "PUT ") {
476       OPENSSL_PUT_ERROR(SSL, SSL_R_HTTP_REQUEST);
477       *out_alert = 0;
478       return ssl_open_record_error;
479     }
480     if (str.substr(0, 5) == "CONNE") {
481       OPENSSL_PUT_ERROR(SSL, SSL_R_HTTPS_PROXY_REQUEST);
482       *out_alert = 0;
483       return ssl_open_record_error;
484     }
485 
486     // Check for a V2ClientHello.
487     if ((in[0] & 0x80) != 0 && in[2] == SSL2_MT_CLIENT_HELLO &&
488         in[3] == SSL3_VERSION_MAJOR) {
489       auto ret = read_v2_client_hello(ssl, out_consumed, in);
490       if (ret == ssl_open_record_error) {
491         *out_alert = 0;
492       } else if (ret == ssl_open_record_success) {
493         ssl->s3->v2_hello_done = true;
494       }
495       return ret;
496     }
497 
498     ssl->s3->v2_hello_done = true;
499   }
500 
501   uint8_t type;
502   Span<uint8_t> body;
503   auto ret = tls_open_record(ssl, &type, &body, out_consumed, out_alert, in);
504   if (ret != ssl_open_record_success) {
505     return ret;
506   }
507 
508   if (type != SSL3_RT_HANDSHAKE) {
509     OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
510     *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
511     return ssl_open_record_error;
512   }
513 
514   // Append the entire handshake record to the buffer.
515   if (!tls_append_handshake_data(ssl, body)) {
516     *out_alert = SSL_AD_INTERNAL_ERROR;
517     return ssl_open_record_error;
518   }
519 
520   return ssl_open_record_success;
521 }
522 
tls_next_message(SSL * ssl)523 void tls_next_message(SSL *ssl) {
524   SSLMessage msg;
525   if (!tls_get_message(ssl, &msg) ||  //
526       !ssl->s3->hs_buf ||             //
527       ssl->s3->hs_buf->length < CBS_len(&msg.raw)) {
528     assert(0);
529     return;
530   }
531 
532   OPENSSL_memmove(ssl->s3->hs_buf->data,
533                   ssl->s3->hs_buf->data + CBS_len(&msg.raw),
534                   ssl->s3->hs_buf->length - CBS_len(&msg.raw));
535   ssl->s3->hs_buf->length -= CBS_len(&msg.raw);
536   ssl->s3->is_v2_hello = false;
537   ssl->s3->has_message = false;
538 
539   // Post-handshake messages are rare, so release the buffer after every
540   // message. During the handshake, |on_handshake_complete| will release it.
541   if (!SSL_in_init(ssl) && ssl->s3->hs_buf->length == 0) {
542     ssl->s3->hs_buf.reset();
543   }
544 }
545 
546 namespace {
547 
548 class CipherScorer {
549  public:
550   using Score = int;
551   static constexpr Score kMinScore = 0;
552 
553   virtual ~CipherScorer() = default;
554 
555   virtual Score Evaluate(const SSL_CIPHER *cipher) const = 0;
556 };
557 
558 // AesHwCipherScorer scores cipher suites based on whether AES is supported in
559 // hardware.
560 class AesHwCipherScorer : public CipherScorer {
561  public:
AesHwCipherScorer(bool has_aes_hw)562   explicit AesHwCipherScorer(bool has_aes_hw) : aes_is_fine_(has_aes_hw) {}
563 
564   virtual ~AesHwCipherScorer() override = default;
565 
Evaluate(const SSL_CIPHER * a) const566   Score Evaluate(const SSL_CIPHER *a) const override {
567     return
568         // Something is always preferable to nothing.
569         1 +
570         // Either AES is fine, or else ChaCha20 is preferred.
571         ((aes_is_fine_ || a->algorithm_enc == SSL_CHACHA20POLY1305) ? 1 : 0);
572   }
573 
574  private:
575   const bool aes_is_fine_;
576 };
577 
578 // CNsaCipherScorer prefers AES-256-GCM over AES-128-GCM over anything else.
579 class CNsaCipherScorer : public CipherScorer {
580  public:
581   virtual ~CNsaCipherScorer() override = default;
582 
Evaluate(const SSL_CIPHER * a) const583   Score Evaluate(const SSL_CIPHER *a) const override {
584     if (a->id == TLS1_3_CK_AES_256_GCM_SHA384) {
585       return 3;
586     } else if (a->id == TLS1_3_CK_AES_128_GCM_SHA256) {
587       return 2;
588     }
589     return 1;
590   }
591 };
592 
593 }  // namespace
594 
ssl_tls13_cipher_meets_policy(uint16_t cipher_id,enum ssl_compliance_policy_t policy)595 bool ssl_tls13_cipher_meets_policy(uint16_t cipher_id,
596                                    enum ssl_compliance_policy_t policy) {
597   switch (policy) {
598     case ssl_compliance_policy_none:
599     case ssl_compliance_policy_cnsa_202407:
600       return true;
601 
602     case ssl_compliance_policy_fips_202205:
603       switch (cipher_id) {
604         case TLS1_3_CK_AES_128_GCM_SHA256 & 0xffff:
605         case TLS1_3_CK_AES_256_GCM_SHA384 & 0xffff:
606           return true;
607         case TLS1_3_CK_CHACHA20_POLY1305_SHA256 & 0xffff:
608           return false;
609         default:
610           assert(false);
611           return false;
612       }
613 
614     case ssl_compliance_policy_wpa3_192_202304:
615       switch (cipher_id) {
616         case TLS1_3_CK_AES_256_GCM_SHA384 & 0xffff:
617           return true;
618         case TLS1_3_CK_AES_128_GCM_SHA256 & 0xffff:
619         case TLS1_3_CK_CHACHA20_POLY1305_SHA256 & 0xffff:
620           return false;
621         default:
622           assert(false);
623           return false;
624       }
625   }
626 
627   assert(false);
628   return false;
629 }
630 
ssl_choose_tls13_cipher(CBS cipher_suites,bool has_aes_hw,uint16_t version,enum ssl_compliance_policy_t policy)631 const SSL_CIPHER *ssl_choose_tls13_cipher(CBS cipher_suites, bool has_aes_hw,
632                                           uint16_t version,
633                                           enum ssl_compliance_policy_t policy) {
634   if (CBS_len(&cipher_suites) % 2 != 0) {
635     return nullptr;
636   }
637 
638   const SSL_CIPHER *best = nullptr;
639   AesHwCipherScorer aes_hw_scorer(has_aes_hw);
640   CNsaCipherScorer cnsa_scorer;
641   CipherScorer *const scorer =
642       (policy == ssl_compliance_policy_cnsa_202407)
643           ? static_cast<CipherScorer *>(&cnsa_scorer)
644           : static_cast<CipherScorer *>(&aes_hw_scorer);
645   CipherScorer::Score best_score = CipherScorer::kMinScore;
646 
647   while (CBS_len(&cipher_suites) > 0) {
648     uint16_t cipher_suite;
649     if (!CBS_get_u16(&cipher_suites, &cipher_suite)) {
650       return nullptr;
651     }
652 
653     // Limit to TLS 1.3 ciphers we know about.
654     const SSL_CIPHER *candidate = SSL_get_cipher_by_value(cipher_suite);
655     if (candidate == nullptr ||
656         SSL_CIPHER_get_min_version(candidate) > version ||
657         SSL_CIPHER_get_max_version(candidate) < version) {
658       continue;
659     }
660 
661     if (!ssl_tls13_cipher_meets_policy(SSL_CIPHER_get_protocol_id(candidate),
662                                        policy)) {
663       continue;
664     }
665 
666     const CipherScorer::Score candidate_score = scorer->Evaluate(candidate);
667     // |candidate_score| must be larger to displace the current choice. That way
668     // the client's order controls between ciphers with an equal score.
669     if (candidate_score > best_score) {
670       best = candidate;
671       best_score = candidate_score;
672     }
673   }
674 
675   return best;
676 }
677 
678 BSSL_NAMESPACE_END
679