• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2005-2016 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the OpenSSL license (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 #include <openssl/ssl.h>
11 
12 #include <assert.h>
13 #include <limits.h>
14 #include <string.h>
15 
16 #include <algorithm>
17 
18 #include <openssl/err.h>
19 #include <openssl/evp.h>
20 #include <openssl/mem.h>
21 #include <openssl/rand.h>
22 
23 #include "../crypto/internal.h"
24 #include "internal.h"
25 
26 
27 BSSL_NAMESPACE_BEGIN
28 
29 // TODO(davidben): 28 comes from the size of IP + UDP header. Is this reasonable
30 // for these values? Notably, why is kMinMTU a function of the transport
31 // protocol's overhead rather than, say, what's needed to hold a minimally-sized
32 // handshake fragment plus protocol overhead.
33 
// kIPAndUDPHeaderLen is the size of the IP + UDP headers, which is subtracted
// from the base sizes below.
static constexpr unsigned int kIPAndUDPHeaderLen = 28;

// kMinMTU is the smallest MTU value that will be accepted.
static constexpr unsigned int kMinMTU = 256 - kIPAndUDPHeaderLen;

// kDefaultMTU is the MTU used when neither the user nor the underlying BIO
// supplies one.
static constexpr unsigned int kDefaultMTU = 1500 - kIPAndUDPHeaderLen;
40 
// BitRange builds a |uint8_t| mask in which bit positions |start| (inclusive)
// through |end| (exclusive) are set and every other bit is clear. Bit 0 is the
// least significant bit. Requires |start| <= |end| <= 8.
static uint8_t BitRange(size_t start, size_t end) {
  assert(start <= end && end <= 8);
  // All bits strictly below |end| ...
  const unsigned below_end = (1u << end) - 1;
  // ... minus all bits strictly below |start|.
  const unsigned below_start = (1u << start) - 1;
  return static_cast<uint8_t>(below_end & ~below_start);
}
47 
48 // FirstUnmarkedRangeInByte returns the first unmarked range in bits |b|.
FirstUnmarkedRangeInByte(uint8_t b)49 static DTLSMessageBitmap::Range FirstUnmarkedRangeInByte(uint8_t b) {
50   size_t start, end;
51   for (start = 0; start < 8; start++) {
52     if ((b & (1u << start)) == 0) {
53       break;
54     }
55   }
56   for (end = start; end < 8; end++) {
57     if ((b & (1u << end)) != 0) {
58       break;
59     }
60   }
61   return DTLSMessageBitmap::Range{start, end};
62 }
63 
Init(size_t num_bits)64 bool DTLSMessageBitmap::Init(size_t num_bits) {
65   if (num_bits + 7 < num_bits) {
66     OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
67     return false;
68   }
69   size_t num_bytes = (num_bits + 7) / 8;
70   size_t bits_rounded = num_bytes * 8;
71   if (!bytes_.Init(num_bytes)) {
72     return false;
73   }
74   MarkRange(num_bits, bits_rounded);
75   first_unmarked_byte_ = 0;
76   return true;
77 }
78 
MarkRange(size_t start,size_t end)79 void DTLSMessageBitmap::MarkRange(size_t start, size_t end) {
80   assert(start <= end);
81   // Don't bother touching bytes that have already been marked.
82   start = std::max(start, first_unmarked_byte_ << 3);
83   // Clamp everything within range.
84   start = std::min(start, bytes_.size() << 3);
85   end = std::min(end, bytes_.size() << 3);
86   if (start >= end) {
87     return;
88   }
89 
90   if ((start >> 3) == (end >> 3)) {
91     bytes_[start >> 3] |= BitRange(start & 7, end & 7);
92   } else {
93     bytes_[start >> 3] |= BitRange(start & 7, 8);
94     for (size_t i = (start >> 3) + 1; i < (end >> 3); i++) {
95       bytes_[i] = 0xff;
96     }
97     if ((end & 7) != 0) {
98       bytes_[end >> 3] |= BitRange(0, end & 7);
99     }
100   }
101 
102   // Maintain the |first_unmarked_byte_| invariant. This work is amortized
103   // across all |MarkRange| calls.
104   while (first_unmarked_byte_ < bytes_.size() &&
105          bytes_[first_unmarked_byte_] == 0xff) {
106     first_unmarked_byte_++;
107   }
108   // If the whole message is marked, we no longer need to spend memory on the
109   // bitmap.
110   if (first_unmarked_byte_ >= bytes_.size()) {
111     bytes_.Reset();
112     first_unmarked_byte_ = 0;
113   }
114 }
115 
NextUnmarkedRange(size_t start) const116 DTLSMessageBitmap::Range DTLSMessageBitmap::NextUnmarkedRange(
117     size_t start) const {
118   // Don't bother looking at bytes that are known to be fully marked.
119   start = std::max(start, first_unmarked_byte_ << 3);
120 
121   size_t idx = start >> 3;
122   if (idx >= bytes_.size()) {
123     return Range{0, 0};
124   }
125 
126   // Look at the bits from |start| up to a byte boundary.
127   uint8_t byte = bytes_[idx] | BitRange(0, start & 7);
128   if (byte == 0xff) {
129     // Nothing unmarked at this byte. Keep searching for an unmarked bit.
130     for (idx = idx + 1; idx < bytes_.size(); idx++) {
131       if (bytes_[idx] != 0xff) {
132         byte = bytes_[idx];
133         break;
134       }
135     }
136     if (idx >= bytes_.size()) {
137       return Range{0, 0};
138     }
139   }
140 
141   Range range = FirstUnmarkedRangeInByte(byte);
142   assert(!range.empty());
143   bool should_extend = range.end == 8;
144   range.start += idx << 3;
145   range.end += idx << 3;
146   if (!should_extend) {
147     // The range did not end at a byte boundary. We're done.
148     return range;
149   }
150 
151   // Collect all fully unmarked bytes.
152   for (idx = idx + 1; idx < bytes_.size(); idx++) {
153     if (bytes_[idx] != 0) {
154       break;
155     }
156   }
157   range.end = idx << 3;
158 
159   // Add any bits from the remaining byte, if any.
160   if (idx < bytes_.size()) {
161     Range extra = FirstUnmarkedRangeInByte(bytes_[idx]);
162     if (extra.start == 0) {
163       range.end += extra.end;
164     }
165   }
166 
167   return range;
168 }
169 
170 // Receiving handshake messages.
171 
dtls_new_incoming_message(const struct hm_header_st * msg_hdr)172 static UniquePtr<DTLSIncomingMessage> dtls_new_incoming_message(
173     const struct hm_header_st *msg_hdr) {
174   ScopedCBB cbb;
175   UniquePtr<DTLSIncomingMessage> frag = MakeUnique<DTLSIncomingMessage>();
176   if (!frag) {
177     return nullptr;
178   }
179   frag->type = msg_hdr->type;
180   frag->seq = msg_hdr->seq;
181 
182   // Allocate space for the reassembled message and fill in the header.
183   if (!frag->data.InitForOverwrite(DTLS1_HM_HEADER_LENGTH + msg_hdr->msg_len)) {
184     return nullptr;
185   }
186 
187   if (!CBB_init_fixed(cbb.get(), frag->data.data(), DTLS1_HM_HEADER_LENGTH) ||
188       !CBB_add_u8(cbb.get(), msg_hdr->type) ||
189       !CBB_add_u24(cbb.get(), msg_hdr->msg_len) ||
190       !CBB_add_u16(cbb.get(), msg_hdr->seq) ||
191       !CBB_add_u24(cbb.get(), 0 /* frag_off */) ||
192       !CBB_add_u24(cbb.get(), msg_hdr->msg_len) ||
193       !CBB_finish(cbb.get(), NULL, NULL)) {
194     return nullptr;
195   }
196 
197   if (!frag->reassembly.Init(msg_hdr->msg_len)) {
198     return nullptr;
199   }
200 
201   return frag;
202 }
203 
204 // dtls1_is_current_message_complete returns whether the current handshake
205 // message is complete.
dtls1_is_current_message_complete(const SSL * ssl)206 static bool dtls1_is_current_message_complete(const SSL *ssl) {
207   size_t idx = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
208   DTLSIncomingMessage *frag = ssl->d1->incoming_messages[idx].get();
209   return frag != nullptr && frag->reassembly.IsComplete();
210 }
211 
212 // dtls1_get_incoming_message returns the incoming message corresponding to
213 // |msg_hdr|. If none exists, it creates a new one and inserts it in the
214 // queue. Otherwise, it checks |msg_hdr| is consistent with the existing one. It
215 // returns NULL on failure. The caller does not take ownership of the result.
dtls1_get_incoming_message(SSL * ssl,uint8_t * out_alert,const struct hm_header_st * msg_hdr)216 static DTLSIncomingMessage *dtls1_get_incoming_message(
217     SSL *ssl, uint8_t *out_alert, const struct hm_header_st *msg_hdr) {
218   if (msg_hdr->seq < ssl->d1->handshake_read_seq ||
219       msg_hdr->seq - ssl->d1->handshake_read_seq >= SSL_MAX_HANDSHAKE_FLIGHT) {
220     *out_alert = SSL_AD_INTERNAL_ERROR;
221     return NULL;
222   }
223 
224   size_t idx = msg_hdr->seq % SSL_MAX_HANDSHAKE_FLIGHT;
225   DTLSIncomingMessage *frag = ssl->d1->incoming_messages[idx].get();
226   if (frag != NULL) {
227     assert(frag->seq == msg_hdr->seq);
228     // The new fragment must be compatible with the previous fragments from this
229     // message.
230     if (frag->type != msg_hdr->type ||  //
231         frag->msg_len() != msg_hdr->msg_len) {
232       OPENSSL_PUT_ERROR(SSL, SSL_R_FRAGMENT_MISMATCH);
233       *out_alert = SSL_AD_ILLEGAL_PARAMETER;
234       return NULL;
235     }
236     return frag;
237   }
238 
239   // This is the first fragment from this message.
240   ssl->d1->incoming_messages[idx] = dtls_new_incoming_message(msg_hdr);
241   if (!ssl->d1->incoming_messages[idx]) {
242     *out_alert = SSL_AD_INTERNAL_ERROR;
243     return NULL;
244   }
245   return ssl->d1->incoming_messages[idx].get();
246 }
247 
dtls1_process_handshake_fragments(SSL * ssl,uint8_t * out_alert,DTLSRecordNumber record_number,Span<const uint8_t> record)248 bool dtls1_process_handshake_fragments(SSL *ssl, uint8_t *out_alert,
249                                        DTLSRecordNumber record_number,
250                                        Span<const uint8_t> record) {
251   bool implicit_ack = false;
252   bool skipped_fragments = false;
253   CBS cbs = record;
254   while (CBS_len(&cbs) > 0) {
255     // Read a handshake fragment.
256     struct hm_header_st msg_hdr;
257     CBS body;
258     if (!dtls1_parse_fragment(&cbs, &msg_hdr, &body)) {
259       OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_HANDSHAKE_RECORD);
260       *out_alert = SSL_AD_DECODE_ERROR;
261       return false;
262     }
263 
264     const size_t frag_off = msg_hdr.frag_off;
265     const size_t frag_len = msg_hdr.frag_len;
266     const size_t msg_len = msg_hdr.msg_len;
267     if (frag_off > msg_len || frag_len > msg_len - frag_off) {
268       OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_HANDSHAKE_RECORD);
269       *out_alert = SSL_AD_ILLEGAL_PARAMETER;
270       return false;
271     }
272 
273     if (msg_hdr.seq < ssl->d1->handshake_read_seq ||
274         ssl->d1->handshake_read_overflow) {
275       // Ignore fragments from the past. This is a retransmit of data we already
276       // received.
277       //
278       // TODO(crbug.com/42290594): Use this to drive retransmits.
279       continue;
280     }
281 
282     if (record_number.epoch() != ssl->d1->read_epoch.epoch ||
283         ssl->d1->next_read_epoch != nullptr) {
284       // New messages can only arrive in the latest epoch. This can fail if the
285       // record came from |prev_read_epoch|, or if it came from |read_epoch| but
286       // |next_read_epoch| exists. (It cannot come from |next_read_epoch|
287       // because |next_read_epoch| becomes |read_epoch| once it receives a
288       // record.)
289       OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESS_HANDSHAKE_DATA);
290       *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
291       return false;
292     }
293 
294     if (msg_len > ssl_max_handshake_message_len(ssl)) {
295       OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
296       *out_alert = SSL_AD_ILLEGAL_PARAMETER;
297       return false;
298     }
299 
300     if (SSL_in_init(ssl) && ssl_has_final_version(ssl) &&
301         ssl_protocol_version(ssl) >= TLS1_3_VERSION) {
302       // During the handshake, if we receive any portion of the next flight, the
303       // peer must have received our most recent flight. In DTLS 1.3, this is an
304       // implicit ACK. See RFC 9147, Section 7.1.
305       //
306       // This only applies during the handshake. After the handshake, the next
307       // message may be part of a post-handshake transaction. It also does not
308       // apply immediately after the handshake. As a client, receiving a
309       // KeyUpdate or NewSessionTicket does not imply the server has received
310       // our Finished. The server may have sent those messages in half-RTT.
311       implicit_ack = true;
312     }
313 
314     if (msg_hdr.seq - ssl->d1->handshake_read_seq > SSL_MAX_HANDSHAKE_FLIGHT) {
315       // Ignore fragments too far in the future.
316       skipped_fragments = true;
317       continue;
318     }
319 
320     DTLSIncomingMessage *frag =
321         dtls1_get_incoming_message(ssl, out_alert, &msg_hdr);
322     if (frag == nullptr) {
323       return false;
324     }
325     assert(frag->msg_len() == msg_len);
326 
327     if (frag->reassembly.IsComplete()) {
328       // The message is already assembled.
329       continue;
330     }
331     assert(msg_len > 0);
332 
333     // Copy the body into the fragment.
334     Span<uint8_t> dest = frag->msg().subspan(frag_off, CBS_len(&body));
335     OPENSSL_memcpy(dest.data(), CBS_data(&body), CBS_len(&body));
336     frag->reassembly.MarkRange(frag_off, frag_off + frag_len);
337   }
338 
339   if (implicit_ack) {
340     dtls1_stop_timer(ssl);
341     dtls_clear_outgoing_messages(ssl);
342   }
343 
344   if (!skipped_fragments) {
345     ssl->d1->records_to_ack.PushBack(record_number);
346 
347     if (ssl_has_final_version(ssl) &&
348         ssl_protocol_version(ssl) >= TLS1_3_VERSION &&
349         !ssl->d1->ack_timer.IsSet() && !ssl->d1->sending_ack) {
350       // Schedule sending an ACK. The delay serves several purposes:
351       // - If there are more records to come, we send only one ACK.
352       // - If there are more records to come and the flight is now complete, we
353       //   will send the reply (which implicitly ACKs the previous flight) and
354       //   cancel the timer.
355       // - If there are more records to come, the flight is now complete, but
356       //   generating the response is delayed (e.g. a slow, async private key),
357       //   the timer will fire and we send an ACK anyway.
358       OPENSSL_timeval now = ssl_ctx_get_current_time(ssl->ctx.get());
359       ssl->d1->ack_timer.StartMicroseconds(
360           now, uint64_t{ssl->d1->timeout_duration_ms} * 1000 / 4);
361     }
362   }
363 
364   return true;
365 }
366 
dtls1_open_handshake(SSL * ssl,size_t * out_consumed,uint8_t * out_alert,Span<uint8_t> in)367 ssl_open_record_t dtls1_open_handshake(SSL *ssl, size_t *out_consumed,
368                                        uint8_t *out_alert, Span<uint8_t> in) {
369   uint8_t type;
370   DTLSRecordNumber record_number;
371   Span<uint8_t> record;
372   auto ret = dtls_open_record(ssl, &type, &record_number, &record, out_consumed,
373                               out_alert, in);
374   if (ret != ssl_open_record_success) {
375     return ret;
376   }
377 
378   switch (type) {
379     case SSL3_RT_APPLICATION_DATA:
380       // In DTLS 1.2, out-of-order application data may be received between
381       // ChangeCipherSpec and Finished. Discard it.
382       return ssl_open_record_discard;
383 
384     case SSL3_RT_CHANGE_CIPHER_SPEC:
385       if (record.size() != 1u || record[0] != SSL3_MT_CCS) {
386         OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_CHANGE_CIPHER_SPEC);
387         *out_alert = SSL_AD_ILLEGAL_PARAMETER;
388         return ssl_open_record_error;
389       }
390 
391       // We do not support renegotiation, so encrypted ChangeCipherSpec records
392       // are illegal.
393       if (record_number.epoch() != 0) {
394         OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
395         *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
396         return ssl_open_record_error;
397       }
398 
399       // Ignore ChangeCipherSpec from a previous epoch.
400       if (record_number.epoch() != ssl->d1->read_epoch.epoch) {
401         return ssl_open_record_discard;
402       }
403 
404       // Flag the ChangeCipherSpec for later.
405       // TODO(crbug.com/42290594): Should we reject this in DTLS 1.3?
406       ssl->d1->has_change_cipher_spec = true;
407       ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_CHANGE_CIPHER_SPEC,
408                           record);
409       return ssl_open_record_success;
410 
411     case SSL3_RT_ACK:
412       return dtls1_process_ack(ssl, out_alert, record_number, record);
413 
414     case SSL3_RT_HANDSHAKE:
415       if (!dtls1_process_handshake_fragments(ssl, out_alert, record_number,
416                                              record)) {
417         return ssl_open_record_error;
418       }
419       return ssl_open_record_success;
420 
421     default:
422       OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
423       *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
424       return ssl_open_record_error;
425   }
426 }
427 
dtls1_get_message(const SSL * ssl,SSLMessage * out)428 bool dtls1_get_message(const SSL *ssl, SSLMessage *out) {
429   if (!dtls1_is_current_message_complete(ssl)) {
430     return false;
431   }
432 
433   size_t idx = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
434   const DTLSIncomingMessage *frag = ssl->d1->incoming_messages[idx].get();
435   out->type = frag->type;
436   out->raw = CBS(frag->data);
437   out->body = CBS(frag->msg());
438   out->is_v2_hello = false;
439   if (!ssl->s3->has_message) {
440     ssl_do_msg_callback(ssl, 0 /* read */, SSL3_RT_HANDSHAKE, out->raw);
441     ssl->s3->has_message = true;
442   }
443   return true;
444 }
445 
dtls1_next_message(SSL * ssl)446 void dtls1_next_message(SSL *ssl) {
447   assert(ssl->s3->has_message);
448   assert(dtls1_is_current_message_complete(ssl));
449   size_t index = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
450   ssl->d1->incoming_messages[index].reset();
451   ssl->d1->handshake_read_seq++;
452   if (ssl->d1->handshake_read_seq == 0) {
453     ssl->d1->handshake_read_overflow = true;
454   }
455   ssl->s3->has_message = false;
456   // If we previously sent a flight, mark it as having a reply, so
457   // |on_handshake_complete| can manage post-handshake retransmission.
458   if (ssl->d1->outgoing_messages_complete) {
459     ssl->d1->flight_has_reply = true;
460   }
461 }
462 
dtls_has_unprocessed_handshake_data(const SSL * ssl)463 bool dtls_has_unprocessed_handshake_data(const SSL *ssl) {
464   size_t current = ssl->d1->handshake_read_seq % SSL_MAX_HANDSHAKE_FLIGHT;
465   for (size_t i = 0; i < SSL_MAX_HANDSHAKE_FLIGHT; i++) {
466     // Skip the current message.
467     if (ssl->s3->has_message && i == current) {
468       assert(dtls1_is_current_message_complete(ssl));
469       continue;
470     }
471     if (ssl->d1->incoming_messages[i] != nullptr) {
472       return true;
473     }
474   }
475   return false;
476 }
477 
dtls1_parse_fragment(CBS * cbs,struct hm_header_st * out_hdr,CBS * out_body)478 bool dtls1_parse_fragment(CBS *cbs, struct hm_header_st *out_hdr,
479                           CBS *out_body) {
480   OPENSSL_memset(out_hdr, 0x00, sizeof(struct hm_header_st));
481 
482   if (!CBS_get_u8(cbs, &out_hdr->type) ||
483       !CBS_get_u24(cbs, &out_hdr->msg_len) ||
484       !CBS_get_u16(cbs, &out_hdr->seq) ||
485       !CBS_get_u24(cbs, &out_hdr->frag_off) ||
486       !CBS_get_u24(cbs, &out_hdr->frag_len) ||
487       !CBS_get_bytes(cbs, out_body, out_hdr->frag_len)) {
488     return false;
489   }
490 
491   return true;
492 }
493 
dtls1_open_change_cipher_spec(SSL * ssl,size_t * out_consumed,uint8_t * out_alert,Span<uint8_t> in)494 ssl_open_record_t dtls1_open_change_cipher_spec(SSL *ssl, size_t *out_consumed,
495                                                 uint8_t *out_alert,
496                                                 Span<uint8_t> in) {
497   if (!ssl->d1->has_change_cipher_spec) {
498     // dtls1_open_handshake processes both handshake and ChangeCipherSpec.
499     auto ret = dtls1_open_handshake(ssl, out_consumed, out_alert, in);
500     if (ret != ssl_open_record_success) {
501       return ret;
502     }
503   }
504   if (ssl->d1->has_change_cipher_spec) {
505     ssl->d1->has_change_cipher_spec = false;
506     return ssl_open_record_success;
507   }
508   return ssl_open_record_discard;
509 }
510 
511 
512 // Sending handshake messages.
513 
dtls_clear_outgoing_messages(SSL * ssl)514 void dtls_clear_outgoing_messages(SSL *ssl) {
515   ssl->d1->outgoing_messages.clear();
516   ssl->d1->sent_records = nullptr;
517   ssl->d1->outgoing_written = 0;
518   ssl->d1->outgoing_offset = 0;
519   ssl->d1->outgoing_messages_complete = false;
520   ssl->d1->flight_has_reply = false;
521   ssl->d1->sending_flight = false;
522   dtls_clear_unused_write_epochs(ssl);
523 }
524 
dtls_clear_unused_write_epochs(SSL * ssl)525 void dtls_clear_unused_write_epochs(SSL *ssl) {
526   ssl->d1->extra_write_epochs.EraseIf(
527       [ssl](const UniquePtr<DTLSWriteEpoch> &write_epoch) -> bool {
528         // Non-current epochs may be discarded once there are no incomplete
529         // outgoing messages that reference them.
530         //
531         // TODO(crbug.com/42290594): Epoch 1 (0-RTT) should be retained until
532         // epoch 3 (app data) is available.
533         for (const auto &msg : ssl->d1->outgoing_messages) {
534           if (msg.epoch == write_epoch->epoch() && !msg.IsFullyAcked()) {
535             return false;
536           }
537         }
538         return true;
539       });
540 }
541 
dtls1_init_message(const SSL * ssl,CBB * cbb,CBB * body,uint8_t type)542 bool dtls1_init_message(const SSL *ssl, CBB *cbb, CBB *body, uint8_t type) {
543   // Pick a modest size hint to save most of the |realloc| calls.
544   if (!CBB_init(cbb, 64) ||                                   //
545       !CBB_add_u8(cbb, type) ||                               //
546       !CBB_add_u24(cbb, 0 /* length (filled in later) */) ||  //
547       !CBB_add_u16(cbb, ssl->d1->handshake_write_seq) ||      //
548       !CBB_add_u24(cbb, 0 /* offset */) ||                    //
549       !CBB_add_u24_length_prefixed(cbb, body)) {
550     return false;
551   }
552 
553   return true;
554 }
555 
dtls1_finish_message(const SSL * ssl,CBB * cbb,Array<uint8_t> * out_msg)556 bool dtls1_finish_message(const SSL *ssl, CBB *cbb, Array<uint8_t> *out_msg) {
557   if (!CBBFinishArray(cbb, out_msg) ||
558       out_msg->size() < DTLS1_HM_HEADER_LENGTH) {
559     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
560     return false;
561   }
562 
563   // Fix up the header. Copy the fragment length into the total message
564   // length.
565   OPENSSL_memcpy(out_msg->data() + 1,
566                  out_msg->data() + DTLS1_HM_HEADER_LENGTH - 3, 3);
567   return true;
568 }
569 
570 // add_outgoing adds a new handshake message or ChangeCipherSpec to the current
571 // outgoing flight. It returns true on success and false on error.
add_outgoing(SSL * ssl,bool is_ccs,Array<uint8_t> data)572 static bool add_outgoing(SSL *ssl, bool is_ccs, Array<uint8_t> data) {
573   if (ssl->d1->outgoing_messages_complete) {
574     // If we've begun writing a new flight, we received the peer flight. Discard
575     // the timer and the our flight.
576     dtls1_stop_timer(ssl);
577     dtls_clear_outgoing_messages(ssl);
578   }
579 
580   if (!is_ccs) {
581     if (ssl->d1->handshake_write_overflow) {
582       OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
583       return false;
584     }
585     // TODO(svaldez): Move this up a layer to fix abstraction for SSLTranscript
586     // on hs.
587     if (ssl->s3->hs != NULL && !ssl->s3->hs->transcript.Update(data)) {
588       OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
589       return false;
590     }
591     ssl->d1->handshake_write_seq++;
592     if (ssl->d1->handshake_write_seq == 0) {
593       ssl->d1->handshake_write_overflow = true;
594     }
595   }
596 
597   DTLSOutgoingMessage msg;
598   msg.data = std::move(data);
599   msg.epoch = ssl->d1->write_epoch.epoch();
600   msg.is_ccs = is_ccs;
601   // Zero-length messages need 1 bit to track whether the peer has received the
602   // message header. (Normally the message header is implicitly received when
603   // any fragment of the message is received at all.)
604   if (!is_ccs && !msg.acked.Init(std::max(msg.msg_len(), size_t{1}))) {
605     return false;
606   }
607 
608   // This should not fail if |SSL_MAX_HANDSHAKE_FLIGHT| was sized correctly.
609   //
610   // TODO(crbug.com/42290594): This can currently fail in DTLS 1.3. The caller
611   // can configure how many tickets to send, up to kMaxTickets. Additionally, if
612   // we send 0.5-RTT tickets in 0-RTT, we may even have tickets queued up with
613   // the server flight.
614   if (!ssl->d1->outgoing_messages.TryPushBack(std::move(msg))) {
615     assert(false);
616     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
617     return false;
618   }
619 
620   return true;
621 }
622 
dtls1_add_message(SSL * ssl,Array<uint8_t> data)623 bool dtls1_add_message(SSL *ssl, Array<uint8_t> data) {
624   return add_outgoing(ssl, false /* handshake */, std::move(data));
625 }
626 
dtls1_add_change_cipher_spec(SSL * ssl)627 bool dtls1_add_change_cipher_spec(SSL *ssl) {
628   // DTLS 1.3 disables compatibility mode, which means that DTLS 1.3 never sends
629   // a ChangeCipherSpec message.
630   if (ssl_protocol_version(ssl) > TLS1_2_VERSION) {
631     return true;
632   }
633   return add_outgoing(ssl, true /* ChangeCipherSpec */, Array<uint8_t>());
634 }
635 
636 // dtls1_update_mtu updates the current MTU from the BIO, ensuring it is above
637 // the minimum.
dtls1_update_mtu(SSL * ssl)638 static void dtls1_update_mtu(SSL *ssl) {
639   // TODO(davidben): No consumer implements |BIO_CTRL_DGRAM_SET_MTU| and the
640   // only |BIO_CTRL_DGRAM_QUERY_MTU| implementation could use
641   // |SSL_set_mtu|. Does this need to be so complex?
642   if (ssl->d1->mtu < dtls1_min_mtu() &&
643       !(SSL_get_options(ssl) & SSL_OP_NO_QUERY_MTU)) {
644     long mtu = BIO_ctrl(ssl->wbio.get(), BIO_CTRL_DGRAM_QUERY_MTU, 0, NULL);
645     if (mtu >= 0 && mtu <= (1 << 30) && (unsigned)mtu >= dtls1_min_mtu()) {
646       ssl->d1->mtu = (unsigned)mtu;
647     } else {
648       ssl->d1->mtu = kDefaultMTU;
649       BIO_ctrl(ssl->wbio.get(), BIO_CTRL_DGRAM_SET_MTU, ssl->d1->mtu, NULL);
650     }
651   }
652 
653   // The MTU should be above the minimum now.
654   assert(ssl->d1->mtu >= dtls1_min_mtu());
655 }
656 
// seal_result_t is the outcome of a single |seal_next_record| call.
enum seal_result_t {
  // Sealing failed.
  seal_error,
  // The next message could not be combined into the record just sealed. The
  // caller should call again.
  seal_continue,
  // The packet is complete, either because there are no more messages or
  // because the packet is full.
  seal_flush,
};
662 
663 // seal_next_record seals one record's worth of messages to |out| and advances
664 // |ssl|'s internal state past the data that was sealed. If progress was made,
665 // it returns |seal_flush| or |seal_continue| and sets
666 // |*out_len| to the number of bytes written.
667 //
668 // If the function stopped because the next message could not be combined into
669 // this record, it returns |seal_continue| and the caller should loop again.
670 // Otherwise, it returns |seal_flush| and the packet is complete (either because
671 // there are no more messages or the packet is full).
seal_next_record(SSL * ssl,Span<uint8_t> out,size_t * out_len)672 static seal_result_t seal_next_record(SSL *ssl, Span<uint8_t> out,
673                                       size_t *out_len) {
674   *out_len = 0;
675 
676   // Skip any fully acked messages.
677   while (ssl->d1->outgoing_written < ssl->d1->outgoing_messages.size() &&
678          ssl->d1->outgoing_messages[ssl->d1->outgoing_written].IsFullyAcked()) {
679     ssl->d1->outgoing_offset = 0;
680     ssl->d1->outgoing_written++;
681   }
682 
683   // There was nothing left to write.
684   if (ssl->d1->outgoing_written >= ssl->d1->outgoing_messages.size()) {
685     return seal_flush;
686   }
687 
688   const auto &first_msg = ssl->d1->outgoing_messages[ssl->d1->outgoing_written];
689   size_t prefix_len = dtls_seal_prefix_len(ssl, first_msg.epoch);
690   size_t max_in_len = dtls_seal_max_input_len(ssl, first_msg.epoch, out.size());
691   if (max_in_len == 0) {
692     // There is no room for a single record.
693     return seal_flush;
694   }
695 
696   if (first_msg.is_ccs) {
697     static const uint8_t kChangeCipherSpec[1] = {SSL3_MT_CCS};
698     DTLSRecordNumber record_number;
699     if (!dtls_seal_record(ssl, &record_number, out.data(), out_len, out.size(),
700                           SSL3_RT_CHANGE_CIPHER_SPEC, kChangeCipherSpec,
701                           sizeof(kChangeCipherSpec), first_msg.epoch)) {
702       return seal_error;
703     }
704 
705     ssl_do_msg_callback(ssl, /*is_write=*/1, SSL3_RT_CHANGE_CIPHER_SPEC,
706                         kChangeCipherSpec);
707     ssl->d1->outgoing_offset = 0;
708     ssl->d1->outgoing_written++;
709     return seal_continue;
710   }
711 
712   // TODO(crbug.com/374991962): For now, only send one message per record in
713   // epoch 0. Sending multiple is allowed and more efficient, but breaks
714   // b/378742138.
715   const bool allow_multiple_messages = first_msg.epoch != 0;
716 
717   // Pack as many handshake fragments into one record as we can. We stage the
718   // fragments in the output buffer, to be sealed in-place.
719   bool should_continue = false;
720   Span<uint8_t> fragments = out.subspan(prefix_len, max_in_len);
721   CBB cbb;
722   CBB_init_fixed(&cbb, fragments.data(), fragments.size());
723   DTLSSentRecord sent_record;
724   sent_record.first_msg = ssl->d1->outgoing_written;
725   sent_record.first_msg_start = ssl->d1->outgoing_offset;
726   while (ssl->d1->outgoing_written < ssl->d1->outgoing_messages.size()) {
727     const auto &msg = ssl->d1->outgoing_messages[ssl->d1->outgoing_written];
728     if (msg.epoch != first_msg.epoch || msg.is_ccs) {
729       // We can only pack messages if the epoch matches. There may be more room
730       // in the packet, so tell the caller to keep going.
731       should_continue = true;
732       break;
733     }
734 
735     // Decode |msg|'s header.
736     CBS cbs(msg.data), body_cbs;
737     struct hm_header_st hdr;
738     if (!dtls1_parse_fragment(&cbs, &hdr, &body_cbs) ||  //
739         hdr.frag_off != 0 ||                             //
740         hdr.frag_len != CBS_len(&body_cbs) ||            //
741         hdr.msg_len != CBS_len(&body_cbs) ||             //
742         CBS_len(&cbs) != 0) {
743       OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
744       return seal_error;
745     }
746 
747     // Iterate over every un-acked range in the message, if any.
748     Span<const uint8_t> body = body_cbs;
749     for (;;) {
750       auto range = msg.acked.NextUnmarkedRange(ssl->d1->outgoing_offset);
751       if (range.empty()) {
752         // Advance to the next message.
753         ssl->d1->outgoing_offset = 0;
754         ssl->d1->outgoing_written++;
755         break;
756       }
757 
758       // Determine how much progress can be made (minimum one byte of progress).
759       size_t capacity = fragments.size() - CBB_len(&cbb);
760       if (capacity < DTLS1_HM_HEADER_LENGTH + 1) {
761         goto packet_full;
762       }
763       size_t todo = std::min(range.size(), capacity - DTLS1_HM_HEADER_LENGTH);
764 
765       // Empty messages are special-cased in ACK tracking. We act as if they
766       // have one byte, but in reality that byte is tracking the header.
767       Span<const uint8_t> frag;
768       if (!body.empty()) {
769         frag = body.subspan(range.start, todo);
770       }
771 
772       // Assemble the fragment.
773       size_t frag_start = CBB_len(&cbb);
774       CBB child;
775       if (!CBB_add_u8(&cbb, hdr.type) ||                       //
776           !CBB_add_u24(&cbb, hdr.msg_len) ||                   //
777           !CBB_add_u16(&cbb, hdr.seq) ||                       //
778           !CBB_add_u24(&cbb, range.start) ||                   //
779           !CBB_add_u24_length_prefixed(&cbb, &child) ||        //
780           !CBB_add_bytes(&child, frag.data(), frag.size()) ||  //
781           !CBB_flush(&cbb)) {
782         OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
783         return seal_error;
784       }
785       size_t frag_end = CBB_len(&cbb);
786 
787       // TODO(davidben): It is odd that, on output, we inform the caller of
788       // retransmits and individual fragments, but on input we only inform the
789       // caller of complete messages.
790       ssl_do_msg_callback(ssl, /*is_write=*/1, SSL3_RT_HANDSHAKE,
791                           fragments.subspan(frag_start, frag_end - frag_start));
792 
793       ssl->d1->outgoing_offset = range.start + todo;
794       if (todo < range.size()) {
795         // The packet was the limiting factor.
796         goto packet_full;
797       }
798     }
799 
800     if (!allow_multiple_messages) {
801       should_continue = true;
802       break;
803     }
804   }
805 
806 packet_full:
807   sent_record.last_msg = ssl->d1->outgoing_written;
808   sent_record.last_msg_end = ssl->d1->outgoing_offset;
809 
810   // We could not fit anything. Don't try to make a record.
811   if (CBB_len(&cbb) == 0) {
812     assert(!should_continue);
813     return seal_flush;
814   }
815 
816   if (!dtls_seal_record(ssl, &sent_record.number, out.data(), out_len,
817                         out.size(), SSL3_RT_HANDSHAKE, CBB_data(&cbb),
818                         CBB_len(&cbb), first_msg.epoch)) {
819     return seal_error;
820   }
821 
822   // If DTLS 1.3 (or if the version is not yet known and it may be DTLS 1.3),
823   // save the record number to match against ACKs later.
824   if (ssl->s3->version == 0 || ssl_protocol_version(ssl) >= TLS1_3_VERSION) {
825     if (ssl->d1->sent_records == nullptr) {
826       ssl->d1->sent_records =
827           MakeUnique<MRUQueue<DTLSSentRecord, DTLS_MAX_ACK_BUFFER>>();
828       if (ssl->d1->sent_records == nullptr) {
829         return seal_error;
830       }
831     }
832     ssl->d1->sent_records->PushBack(sent_record);
833   }
834 
835   return should_continue ? seal_continue : seal_flush;
836 }
837 
838 // seal_next_packet writes as much of the next flight as possible to |out| and
839 // advances |ssl->d1->outgoing_written| and |ssl->d1->outgoing_offset| as
840 // appropriate.
seal_next_packet(SSL * ssl,Span<uint8_t> out,size_t * out_len)841 static bool seal_next_packet(SSL *ssl, Span<uint8_t> out, size_t *out_len) {
842   size_t total = 0;
843   for (;;) {
844     size_t len;
845     seal_result_t ret = seal_next_record(ssl, out, &len);
846     switch (ret) {
847       case seal_error:
848         return false;
849 
850       case seal_flush:
851       case seal_continue:
852         out = out.subspan(len);
853         total += len;
854         break;
855     }
856 
857     if (ret == seal_flush) {
858       break;
859     }
860   }
861 
862   *out_len = total;
863   return true;
864 }
865 
send_flight(SSL * ssl)866 static int send_flight(SSL *ssl) {
867   if (ssl->s3->write_shutdown != ssl_shutdown_none) {
868     OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN);
869     return -1;
870   }
871 
872   if (ssl->wbio == nullptr) {
873     OPENSSL_PUT_ERROR(SSL, SSL_R_BIO_NOT_SET);
874     return -1;
875   }
876 
877   if (ssl->d1->num_timeouts > DTLS1_MAX_TIMEOUTS) {
878     OPENSSL_PUT_ERROR(SSL, SSL_R_READ_TIMEOUT_EXPIRED);
879     return -1;
880   }
881 
882   dtls1_update_mtu(ssl);
883 
884   Array<uint8_t> packet;
885   if (!packet.InitForOverwrite(ssl->d1->mtu)) {
886     return -1;
887   }
888 
889   while (ssl->d1->outgoing_written < ssl->d1->outgoing_messages.size()) {
890     uint8_t old_written = ssl->d1->outgoing_written;
891     uint32_t old_offset = ssl->d1->outgoing_offset;
892 
893     size_t packet_len;
894     if (!seal_next_packet(ssl, Span(packet), &packet_len)) {
895       return -1;
896     }
897 
898     if (packet_len == 0 &&
899         ssl->d1->outgoing_written < ssl->d1->outgoing_messages.size()) {
900       // We made no progress with the packet size available, but did not reach
901       // the end.
902       OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL);
903       return false;
904     }
905 
906     if (packet_len != 0) {
907       int bio_ret = BIO_write(ssl->wbio.get(), packet.data(), packet_len);
908       if (bio_ret <= 0) {
909         // Retry this packet the next time around.
910         ssl->d1->outgoing_written = old_written;
911         ssl->d1->outgoing_offset = old_offset;
912         ssl->s3->rwstate = SSL_ERROR_WANT_WRITE;
913         return bio_ret;
914       }
915     }
916   }
917 
918   if (BIO_flush(ssl->wbio.get()) <= 0) {
919     ssl->s3->rwstate = SSL_ERROR_WANT_WRITE;
920     return -1;
921   }
922 
923   return 1;
924 }
925 
dtls1_finish_flight(SSL * ssl)926 void dtls1_finish_flight(SSL *ssl) {
927   if (ssl->d1->outgoing_messages.empty() ||
928       ssl->d1->outgoing_messages_complete) {
929     return;  // Nothing to do.
930   }
931 
932   if (ssl->d1->outgoing_messages[0].epoch <= 2) {
933     // DTLS 1.3 handshake messages (epoch 2 and below) implicitly ACK the
934     // previous flight, so there is no need to ACK previous records. This
935     // clears the ACK buffer slightly earlier than the specification suggests.
936     // See the discussion in
937     // https://mailarchive.ietf.org/arch/msg/tls/kjJnquJOVaWxu5hUCmNzB35eqY0/
938     ssl->d1->records_to_ack.Clear();
939     ssl->d1->ack_timer.Stop();
940     ssl->d1->sending_ack = false;
941   }
942 
943   ssl->d1->outgoing_messages_complete = true;
944   ssl->d1->sending_flight = true;
945   // Stop retransmitting the previous flight. In DTLS 1.3, we'll have stopped
946   // the timer already, but DTLS 1.2 keeps it running until the next flight is
947   // ready.
948   dtls1_stop_timer(ssl);
949 }
950 
dtls1_schedule_ack(SSL * ssl)951 void dtls1_schedule_ack(SSL *ssl) {
952   ssl->d1->ack_timer.Stop();
953   ssl->d1->sending_ack = !ssl->d1->records_to_ack.empty();
954 }
955 
send_ack(SSL * ssl)956 static int send_ack(SSL *ssl) {
957   assert(ssl_protocol_version(ssl) >= TLS1_3_VERSION);
958 
959   // Ensure we don't send so many ACKs that we overflow the MTU. There is a
960   // 2-byte length prefix and each ACK is 16 bytes.
961   dtls1_update_mtu(ssl);
962   size_t max_plaintext =
963       dtls_seal_max_input_len(ssl, ssl->d1->write_epoch.epoch(), ssl->d1->mtu);
964   if (max_plaintext < 2 + 16) {
965     OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL);  // No room for even one ACK.
966     return -1;
967   }
968   size_t num_acks =
969       std::min((max_plaintext - 2) / 16, ssl->d1->records_to_ack.size());
970 
971   // Assemble the ACK. RFC 9147 says to sort ACKs numerically. It is unclear if
972   // other implementations do this, but go ahead and sort for now. See
973   // https://mailarchive.ietf.org/arch/msg/tls/kjJnquJOVaWxu5hUCmNzB35eqY0/.
974   // Remove this if rfc9147bis removes this requirement.
975   InplaceVector<DTLSRecordNumber, DTLS_MAX_ACK_BUFFER> sorted;
976   for (size_t i = ssl->d1->records_to_ack.size() - num_acks;
977        i < ssl->d1->records_to_ack.size(); i++) {
978     sorted.PushBack(ssl->d1->records_to_ack[i]);
979   }
980   std::sort(sorted.begin(), sorted.end());
981 
982   uint8_t buf[2 + 16 * DTLS_MAX_ACK_BUFFER];
983   CBB cbb, child;
984   CBB_init_fixed(&cbb, buf, sizeof(buf));
985   BSSL_CHECK(CBB_add_u16_length_prefixed(&cbb, &child));
986   for (const auto &number : sorted) {
987     BSSL_CHECK(CBB_add_u64(&child, number.epoch()));
988     BSSL_CHECK(CBB_add_u64(&child, number.sequence()));
989   }
990   BSSL_CHECK(CBB_flush(&cbb));
991 
992   // Encrypt it.
993   uint8_t record[DTLS1_3_RECORD_HEADER_WRITE_LENGTH + sizeof(buf) +
994                  1 /* record type */ + EVP_AEAD_MAX_OVERHEAD];
995   size_t record_len;
996   DTLSRecordNumber record_number;
997   if (!dtls_seal_record(ssl, &record_number, record, &record_len,
998                         sizeof(record), SSL3_RT_ACK, CBB_data(&cbb),
999                         CBB_len(&cbb), ssl->d1->write_epoch.epoch())) {
1000     return -1;
1001   }
1002 
1003   ssl_do_msg_callback(ssl, /*is_write=*/1, SSL3_RT_ACK,
1004                       Span(CBB_data(&cbb), CBB_len(&cbb)));
1005 
1006   int bio_ret =
1007       BIO_write(ssl->wbio.get(), record, static_cast<int>(record_len));
1008   if (bio_ret <= 0) {
1009     ssl->s3->rwstate = SSL_ERROR_WANT_WRITE;
1010     return bio_ret;
1011   }
1012 
1013   if (BIO_flush(ssl->wbio.get()) <= 0) {
1014     ssl->s3->rwstate = SSL_ERROR_WANT_WRITE;
1015     return -1;
1016   }
1017 
1018   return 1;
1019 }
1020 
dtls1_flush(SSL * ssl)1021 int dtls1_flush(SSL *ssl) {
1022   // Send the pending ACK, if any.
1023   if (ssl->d1->sending_ack) {
1024     int ret = send_ack(ssl);
1025     if (ret <= 0) {
1026       return ret;
1027     }
1028     ssl->d1->sending_ack = false;
1029   }
1030 
1031   // Send the pending flight, if any.
1032   if (ssl->d1->sending_flight) {
1033     int ret = send_flight(ssl);
1034     if (ret <= 0) {
1035       return ret;
1036     }
1037 
1038     // Reset state for the next send.
1039     ssl->d1->outgoing_written = 0;
1040     ssl->d1->outgoing_offset = 0;
1041     ssl->d1->sending_flight = false;
1042 
1043     // Schedule the next retransmit timer. In DTLS 1.3, we retransmit all
1044     // flights until ACKed. In DTLS 1.2, the final Finished flight is never
1045     // ACKed, so we do not keep the timer running after the handshake.
1046     if (SSL_in_init(ssl) || ssl_protocol_version(ssl) >= TLS1_3_VERSION) {
1047       if (ssl->d1->num_timeouts == 0) {
1048         ssl->d1->timeout_duration_ms = ssl->initial_timeout_duration_ms;
1049       } else {
1050         ssl->d1->timeout_duration_ms =
1051             std::min(ssl->d1->timeout_duration_ms * 2, uint32_t{60000});
1052       }
1053 
1054       OPENSSL_timeval now = ssl_ctx_get_current_time(ssl->ctx.get());
1055       ssl->d1->retransmit_timer.StartMicroseconds(
1056           now, uint64_t{ssl->d1->timeout_duration_ms} * 1000);
1057     }
1058   }
1059 
1060   return 1;
1061 }
1062 
dtls1_min_mtu(void)1063 unsigned int dtls1_min_mtu(void) { return kMinMTU; }
1064 
1065 BSSL_NAMESPACE_END
1066