1 /* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5 /*
6 * DTLS Protocol
7 */
8
9 #include "ssl.h"
10 #include "sslimpl.h"
11 #include "sslproto.h"
12
13 #ifndef PR_ARRAY_SIZE
14 #define PR_ARRAY_SIZE(a) (sizeof(a)/sizeof((a)[0]))
15 #endif
16
17 static SECStatus dtls_TransmitMessageFlight(sslSocket *ss);
18 static void dtls_RetransmitTimerExpiredCb(sslSocket *ss);
19 static SECStatus dtls_SendSavedWriteData(sslSocket *ss);
20
21 /* -28 adjusts for the IP/UDP header */
22 static const PRUint16 COMMON_MTU_VALUES[] = {
23 1500 - 28, /* Ethernet MTU */
24 1280 - 28, /* IPv6 minimum MTU */
25 576 - 28, /* Common assumption */
26 256 - 28 /* We're in serious trouble now */
27 };
28
29 #define DTLS_COOKIE_BYTES 32
30
31 /* List copied from ssl3con.c:cipherSuites */
32 static const ssl3CipherSuite nonDTLSSuites[] = {
33 #ifdef NSS_ENABLE_ECC
34 TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
35 TLS_ECDHE_RSA_WITH_RC4_128_SHA,
36 #endif /* NSS_ENABLE_ECC */
37 TLS_DHE_DSS_WITH_RC4_128_SHA,
38 #ifdef NSS_ENABLE_ECC
39 TLS_ECDH_RSA_WITH_RC4_128_SHA,
40 TLS_ECDH_ECDSA_WITH_RC4_128_SHA,
41 #endif /* NSS_ENABLE_ECC */
42 SSL_RSA_WITH_RC4_128_MD5,
43 SSL_RSA_WITH_RC4_128_SHA,
44 TLS_RSA_EXPORT1024_WITH_RC4_56_SHA,
45 SSL_RSA_EXPORT_WITH_RC4_40_MD5,
46 0 /* End of list marker */
47 };
48
49 /* Map back and forth between TLS and DTLS versions in wire format.
50 * Mapping table is:
51 *
52 * TLS DTLS
53 * 1.1 (0302) 1.0 (feff)
54 */
55 SSL3ProtocolVersion
dtls_TLSVersionToDTLSVersion(SSL3ProtocolVersion tlsv)56 dtls_TLSVersionToDTLSVersion(SSL3ProtocolVersion tlsv)
57 {
58 /* Anything other than TLS 1.1 is an error, so return
59 * the invalid version ffff. */
60 if (tlsv != SSL_LIBRARY_VERSION_TLS_1_1)
61 return 0xffff;
62
63 return SSL_LIBRARY_VERSION_DTLS_1_0_WIRE;
64 }
65
66 /* Map known DTLS versions to known TLS versions.
67 * - Invalid versions (< 1.0) return a version of 0
68 * - Versions > known return a version one higher than we know of
69 * to accomodate a theoretically newer version */
70 SSL3ProtocolVersion
dtls_DTLSVersionToTLSVersion(SSL3ProtocolVersion dtlsv)71 dtls_DTLSVersionToTLSVersion(SSL3ProtocolVersion dtlsv)
72 {
73 if (MSB(dtlsv) == 0xff) {
74 return 0;
75 }
76
77 if (dtlsv == SSL_LIBRARY_VERSION_DTLS_1_0_WIRE)
78 return SSL_LIBRARY_VERSION_TLS_1_1;
79
80 /* Return a fictional higher version than we know of */
81 return SSL_LIBRARY_VERSION_TLS_1_1 + 1;
82 }
83
84 /* On this socket, Disable non-DTLS cipher suites in the argument's list */
85 SECStatus
ssl3_DisableNonDTLSSuites(sslSocket * ss)86 ssl3_DisableNonDTLSSuites(sslSocket * ss)
87 {
88 const ssl3CipherSuite * suite;
89
90 for (suite = nonDTLSSuites; *suite; ++suite) {
91 SECStatus rv = ssl3_CipherPrefSet(ss, *suite, PR_FALSE);
92
93 PORT_Assert(rv == SECSuccess); /* else is coding error */
94 }
95 return SECSuccess;
96 }
97
98 /* Allocate a DTLSQueuedMessage.
99 *
100 * Called from dtls_QueueMessage()
101 */
102 static DTLSQueuedMessage *
dtls_AllocQueuedMessage(PRUint16 epoch,SSL3ContentType type,const unsigned char * data,PRUint32 len)103 dtls_AllocQueuedMessage(PRUint16 epoch, SSL3ContentType type,
104 const unsigned char *data, PRUint32 len)
105 {
106 DTLSQueuedMessage *msg = NULL;
107
108 msg = PORT_ZAlloc(sizeof(DTLSQueuedMessage));
109 if (!msg)
110 return NULL;
111
112 msg->data = PORT_Alloc(len);
113 if (!msg->data) {
114 PORT_Free(msg);
115 return NULL;
116 }
117 PORT_Memcpy(msg->data, data, len);
118
119 msg->len = len;
120 msg->epoch = epoch;
121 msg->type = type;
122
123 return msg;
124 }
125
126 /*
127 * Free a handshake message
128 *
129 * Called from dtls_FreeHandshakeMessages()
130 */
131 static void
dtls_FreeHandshakeMessage(DTLSQueuedMessage * msg)132 dtls_FreeHandshakeMessage(DTLSQueuedMessage *msg)
133 {
134 if (!msg)
135 return;
136
137 PORT_ZFree(msg->data, msg->len);
138 PORT_Free(msg);
139 }
140
141 /*
142 * Free a list of handshake messages
143 *
144 * Called from:
145 * dtls_HandleHandshake()
146 * ssl3_DestroySSL3Info()
147 */
148 void
dtls_FreeHandshakeMessages(PRCList * list)149 dtls_FreeHandshakeMessages(PRCList *list)
150 {
151 PRCList *cur_p;
152
153 while (!PR_CLIST_IS_EMPTY(list)) {
154 cur_p = PR_LIST_TAIL(list);
155 PR_REMOVE_LINK(cur_p);
156 dtls_FreeHandshakeMessage((DTLSQueuedMessage *)cur_p);
157 }
158 }
159
160 /* Called only from ssl3_HandleRecord, for each (deciphered) DTLS record.
161 * origBuf is the decrypted ssl record content and is expected to contain
162 * complete handshake records
163 * Caller must hold the handshake and RecvBuf locks.
164 *
165 * Note that this code uses msg_len for two purposes:
166 *
167 * (1) To pass the length to ssl3_HandleHandshakeMessage()
168 * (2) To carry the length of a message currently being reassembled
169 *
170 * However, unlike ssl3_HandleHandshake(), it is not used to carry
171 * the state of reassembly (i.e., whether one is in progress). That
172 * is carried in recvdHighWater and recvdFragments.
173 */
174 #define OFFSET_BYTE(o) (o/8)
175 #define OFFSET_MASK(o) (1 << (o%8))
176
177 SECStatus
dtls_HandleHandshake(sslSocket * ss,sslBuffer * origBuf)178 dtls_HandleHandshake(sslSocket *ss, sslBuffer *origBuf)
179 {
180 /* XXX OK for now.
181 * This doesn't work properly with asynchronous certificate validation.
182 * because that returns a WOULDBLOCK error. The current DTLS
183 * applications do not need asynchronous validation, but in the
184 * future we will need to add this.
185 */
186 sslBuffer buf = *origBuf;
187 SECStatus rv = SECSuccess;
188
189 PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
190 PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
191
192 while (buf.len > 0) {
193 PRUint8 type;
194 PRUint32 message_length;
195 PRUint16 message_seq;
196 PRUint32 fragment_offset;
197 PRUint32 fragment_length;
198 PRUint32 offset;
199
200 if (buf.len < 12) {
201 PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE);
202 rv = SECFailure;
203 break;
204 }
205
206 /* Parse the header */
207 type = buf.buf[0];
208 message_length = (buf.buf[1] << 16) | (buf.buf[2] << 8) | buf.buf[3];
209 message_seq = (buf.buf[4] << 8) | buf.buf[5];
210 fragment_offset = (buf.buf[6] << 16) | (buf.buf[7] << 8) | buf.buf[8];
211 fragment_length = (buf.buf[9] << 16) | (buf.buf[10] << 8) | buf.buf[11];
212
213 #define MAX_HANDSHAKE_MSG_LEN 0x1ffff /* 128k - 1 */
214 if (message_length > MAX_HANDSHAKE_MSG_LEN) {
215 (void)ssl3_DecodeError(ss);
216 PORT_SetError(SSL_ERROR_RX_RECORD_TOO_LONG);
217 return SECFailure;
218 }
219 #undef MAX_HANDSHAKE_MSG_LEN
220
221 buf.buf += 12;
222 buf.len -= 12;
223
224 /* This fragment must be complete */
225 if (buf.len < fragment_length) {
226 PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE);
227 rv = SECFailure;
228 break;
229 }
230
231 /* Sanity check the packet contents */
232 if ((fragment_length + fragment_offset) > message_length) {
233 PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE);
234 rv = SECFailure;
235 break;
236 }
237
238 /* There are three ways we could not be ready for this packet.
239 *
240 * 1. It's a partial next message.
241 * 2. It's a partial or complete message beyond the next
242 * 3. It's a message we've already seen
243 *
244 * If it's the complete next message we accept it right away.
245 * This is the common case for short messages
246 */
247 if ((message_seq == ss->ssl3.hs.recvMessageSeq)
248 && (fragment_offset == 0)
249 && (fragment_length == message_length)) {
250 /* Complete next message. Process immediately */
251 ss->ssl3.hs.msg_type = (SSL3HandshakeType)type;
252 ss->ssl3.hs.msg_len = message_length;
253
254 /* At this point we are advancing our state machine, so
255 * we can free our last flight of messages */
256 dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight);
257 ss->ssl3.hs.recvdHighWater = -1;
258 dtls_CancelTimer(ss);
259
260 /* Reset the timer to the initial value if the retry counter
261 * is 0, per Sec. 4.2.4.1 */
262 if (ss->ssl3.hs.rtRetries == 0) {
263 ss->ssl3.hs.rtTimeoutMs = INITIAL_DTLS_TIMEOUT_MS;
264 }
265
266 rv = ssl3_HandleHandshakeMessage(ss, buf.buf, ss->ssl3.hs.msg_len);
267 if (rv == SECFailure) {
268 /* Do not attempt to process rest of messages in this record */
269 break;
270 }
271 } else {
272 if (message_seq < ss->ssl3.hs.recvMessageSeq) {
273 /* Case 3: we do an immediate retransmit if we're
274 * in a waiting state*/
275 if (ss->ssl3.hs.rtTimerCb == NULL) {
276 /* Ignore */
277 } else if (ss->ssl3.hs.rtTimerCb ==
278 dtls_RetransmitTimerExpiredCb) {
279 SSL_TRC(30, ("%d: SSL3[%d]: Retransmit detected",
280 SSL_GETPID(), ss->fd));
281 /* Check to see if we retransmitted recently. If so,
282 * suppress the triggered retransmit. This avoids
283 * retransmit wars after packet loss.
284 * This is not in RFC 5346 but should be
285 */
286 if ((PR_IntervalNow() - ss->ssl3.hs.rtTimerStarted) >
287 (ss->ssl3.hs.rtTimeoutMs / 4)) {
288 SSL_TRC(30,
289 ("%d: SSL3[%d]: Shortcutting retransmit timer",
290 SSL_GETPID(), ss->fd));
291
292 /* Cancel the timer and call the CB,
293 * which re-arms the timer */
294 dtls_CancelTimer(ss);
295 dtls_RetransmitTimerExpiredCb(ss);
296 rv = SECSuccess;
297 break;
298 } else {
299 SSL_TRC(30,
300 ("%d: SSL3[%d]: We just retransmitted. Ignoring.",
301 SSL_GETPID(), ss->fd));
302 rv = SECSuccess;
303 break;
304 }
305 } else if (ss->ssl3.hs.rtTimerCb == dtls_FinishedTimerCb) {
306 /* Retransmit the messages and re-arm the timer
307 * Note that we are not backing off the timer here.
308 * The spec isn't clear and my reasoning is that this
309 * may be a re-ordered packet rather than slowness,
310 * so let's be aggressive. */
311 dtls_CancelTimer(ss);
312 rv = dtls_TransmitMessageFlight(ss);
313 if (rv == SECSuccess) {
314 rv = dtls_StartTimer(ss, dtls_FinishedTimerCb);
315 }
316 if (rv != SECSuccess)
317 return rv;
318 break;
319 }
320 } else if (message_seq > ss->ssl3.hs.recvMessageSeq) {
321 /* Case 2
322 *
323 * Ignore this message. This means we don't handle out of
324 * order complete messages that well, but we're still
325 * compliant and this probably does not happen often
326 *
327 * XXX OK for now. Maybe do something smarter at some point?
328 */
329 } else {
330 /* Case 1
331 *
332 * Buffer the fragment for reassembly
333 */
334 /* Make room for the message */
335 if (ss->ssl3.hs.recvdHighWater == -1) {
336 PRUint32 map_length = OFFSET_BYTE(message_length) + 1;
337
338 rv = sslBuffer_Grow(&ss->ssl3.hs.msg_body, message_length);
339 if (rv != SECSuccess)
340 break;
341 /* Make room for the fragment map */
342 rv = sslBuffer_Grow(&ss->ssl3.hs.recvdFragments,
343 map_length);
344 if (rv != SECSuccess)
345 break;
346
347 /* Reset the reassembly map */
348 ss->ssl3.hs.recvdHighWater = 0;
349 PORT_Memset(ss->ssl3.hs.recvdFragments.buf, 0,
350 ss->ssl3.hs.recvdFragments.space);
351 ss->ssl3.hs.msg_type = (SSL3HandshakeType)type;
352 ss->ssl3.hs.msg_len = message_length;
353 }
354
355 /* If we have a message length mismatch, abandon the reassembly
356 * in progress and hope that the next retransmit will give us
357 * something sane
358 */
359 if (message_length != ss->ssl3.hs.msg_len) {
360 ss->ssl3.hs.recvdHighWater = -1;
361 PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE);
362 rv = SECFailure;
363 break;
364 }
365
366 /* Now copy this fragment into the buffer */
367 PORT_Assert((fragment_offset + fragment_length) <=
368 ss->ssl3.hs.msg_body.space);
369 PORT_Memcpy(ss->ssl3.hs.msg_body.buf + fragment_offset,
370 buf.buf, fragment_length);
371
372 /* This logic is a bit tricky. We have two values for
373 * reassembly state:
374 *
375 * - recvdHighWater contains the highest contiguous number of
376 * bytes received
377 * - recvdFragments contains a bitmask of packets received
378 * above recvdHighWater
379 *
380 * This avoids having to fill in the bitmask in the common
381 * case of adjacent fragments received in sequence
382 */
383 if (fragment_offset <= ss->ssl3.hs.recvdHighWater) {
384 /* Either this is the adjacent fragment or an overlapping
385 * fragment */
386 ss->ssl3.hs.recvdHighWater = fragment_offset +
387 fragment_length;
388 } else {
389 for (offset = fragment_offset;
390 offset < fragment_offset + fragment_length;
391 offset++) {
392 ss->ssl3.hs.recvdFragments.buf[OFFSET_BYTE(offset)] |=
393 OFFSET_MASK(offset);
394 }
395 }
396
397 /* Now figure out the new high water mark if appropriate */
398 for (offset = ss->ssl3.hs.recvdHighWater;
399 offset < ss->ssl3.hs.msg_len; offset++) {
400 /* Note that this loop is not efficient, since it counts
401 * bit by bit. If we have a lot of out-of-order packets,
402 * we should optimize this */
403 if (ss->ssl3.hs.recvdFragments.buf[OFFSET_BYTE(offset)] &
404 OFFSET_MASK(offset)) {
405 ss->ssl3.hs.recvdHighWater++;
406 } else {
407 break;
408 }
409 }
410
411 /* If we have all the bytes, then we are good to go */
412 if (ss->ssl3.hs.recvdHighWater == ss->ssl3.hs.msg_len) {
413 ss->ssl3.hs.recvdHighWater = -1;
414
415 rv = ssl3_HandleHandshakeMessage(ss,
416 ss->ssl3.hs.msg_body.buf,
417 ss->ssl3.hs.msg_len);
418 if (rv == SECFailure)
419 break; /* Skip rest of record */
420
421 /* At this point we are advancing our state machine, so
422 * we can free our last flight of messages */
423 dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight);
424 dtls_CancelTimer(ss);
425
426 /* If there have been no retries this time, reset the
427 * timer value to the default per Section 4.2.4.1 */
428 if (ss->ssl3.hs.rtRetries == 0) {
429 ss->ssl3.hs.rtTimeoutMs = INITIAL_DTLS_TIMEOUT_MS;
430 }
431 }
432 }
433 }
434
435 buf.buf += fragment_length;
436 buf.len -= fragment_length;
437 }
438
439 origBuf->len = 0; /* So ssl3_GatherAppDataRecord will keep looping. */
440
441 /* XXX OK for now. In future handle rv == SECWouldBlock safely in order
442 * to deal with asynchronous certificate verification */
443 return rv;
444 }
445
446 /* Enqueue a message (either handshake or CCS)
447 *
448 * Called from:
449 * dtls_StageHandshakeMessage()
450 * ssl3_SendChangeCipherSpecs()
451 */
dtls_QueueMessage(sslSocket * ss,SSL3ContentType type,const SSL3Opaque * pIn,PRInt32 nIn)452 SECStatus dtls_QueueMessage(sslSocket *ss, SSL3ContentType type,
453 const SSL3Opaque *pIn, PRInt32 nIn)
454 {
455 SECStatus rv = SECSuccess;
456 DTLSQueuedMessage *msg = NULL;
457
458 PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
459 PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss));
460
461 msg = dtls_AllocQueuedMessage(ss->ssl3.cwSpec->epoch, type, pIn, nIn);
462
463 if (!msg) {
464 PORT_SetError(SEC_ERROR_NO_MEMORY);
465 rv = SECFailure;
466 } else {
467 PR_APPEND_LINK(&msg->link, &ss->ssl3.hs.lastMessageFlight);
468 }
469
470 return rv;
471 }
472
473 /* Add DTLS handshake message to the pending queue
474 * Empty the sendBuf buffer.
475 * This function returns SECSuccess or SECFailure, never SECWouldBlock.
476 * Always set sendBuf.len to 0, even when returning SECFailure.
477 *
478 * Called from:
479 * ssl3_AppendHandshakeHeader()
480 * dtls_FlushHandshake()
481 */
482 SECStatus
dtls_StageHandshakeMessage(sslSocket * ss)483 dtls_StageHandshakeMessage(sslSocket *ss)
484 {
485 SECStatus rv = SECSuccess;
486
487 PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
488 PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss));
489
490 /* This function is sometimes called when no data is actually to
491 * be staged, so just return SECSuccess. */
492 if (!ss->sec.ci.sendBuf.buf || !ss->sec.ci.sendBuf.len)
493 return rv;
494
495 rv = dtls_QueueMessage(ss, content_handshake,
496 ss->sec.ci.sendBuf.buf, ss->sec.ci.sendBuf.len);
497
498 /* Whether we succeeded or failed, toss the old handshake data. */
499 ss->sec.ci.sendBuf.len = 0;
500 return rv;
501 }
502
503 /* Enqueue the handshake message in sendBuf (if any) and then
504 * transmit the resulting flight of handshake messages.
505 *
506 * Called from:
507 * ssl3_FlushHandshake()
508 */
509 SECStatus
dtls_FlushHandshakeMessages(sslSocket * ss,PRInt32 flags)510 dtls_FlushHandshakeMessages(sslSocket *ss, PRInt32 flags)
511 {
512 SECStatus rv = SECSuccess;
513
514 PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
515 PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss));
516
517 rv = dtls_StageHandshakeMessage(ss);
518 if (rv != SECSuccess)
519 return rv;
520
521 if (!(flags & ssl_SEND_FLAG_FORCE_INTO_BUFFER)) {
522 rv = dtls_TransmitMessageFlight(ss);
523 if (rv != SECSuccess)
524 return rv;
525
526 if (!(flags & ssl_SEND_FLAG_NO_RETRANSMIT)) {
527 ss->ssl3.hs.rtRetries = 0;
528 rv = dtls_StartTimer(ss, dtls_RetransmitTimerExpiredCb);
529 }
530 }
531
532 return rv;
533 }
534
535 /* The callback for when the retransmit timer expires
536 *
537 * Called from:
538 * dtls_CheckTimer()
539 * dtls_HandleHandshake()
540 */
541 static void
dtls_RetransmitTimerExpiredCb(sslSocket * ss)542 dtls_RetransmitTimerExpiredCb(sslSocket *ss)
543 {
544 SECStatus rv = SECFailure;
545
546 ss->ssl3.hs.rtRetries++;
547
548 if (!(ss->ssl3.hs.rtRetries % 3)) {
549 /* If one of the messages was potentially greater than > MTU,
550 * then downgrade. Do this every time we have retransmitted a
551 * message twice, per RFC 6347 Sec. 4.1.1 */
552 dtls_SetMTU(ss, ss->ssl3.hs.maxMessageSent - 1);
553 }
554
555 rv = dtls_TransmitMessageFlight(ss);
556 if (rv == SECSuccess) {
557
558 /* Re-arm the timer */
559 rv = dtls_RestartTimer(ss, PR_TRUE, dtls_RetransmitTimerExpiredCb);
560 }
561
562 if (rv == SECFailure) {
563 /* XXX OK for now. In future maybe signal the stack that we couldn't
564 * transmit. For now, let the read handle any real network errors */
565 }
566 }
567
568 /* Transmit a flight of handshake messages, stuffing them
569 * into as few records as seems reasonable
570 *
571 * Called from:
572 * dtls_FlushHandshake()
573 * dtls_RetransmitTimerExpiredCb()
574 */
575 static SECStatus
dtls_TransmitMessageFlight(sslSocket * ss)576 dtls_TransmitMessageFlight(sslSocket *ss)
577 {
578 SECStatus rv = SECSuccess;
579 PRCList *msg_p;
580 PRUint16 room_left = ss->ssl3.mtu;
581 PRInt32 sent;
582
583 ssl_GetXmitBufLock(ss);
584 ssl_GetSpecReadLock(ss);
585
586 /* DTLS does not buffer its handshake messages in
587 * ss->pendingBuf, but rather in the lastMessageFlight
588 * structure. This is just a sanity check that
589 * some programming error hasn't inadvertantly
590 * stuffed something in ss->pendingBuf
591 */
592 PORT_Assert(!ss->pendingBuf.len);
593 for (msg_p = PR_LIST_HEAD(&ss->ssl3.hs.lastMessageFlight);
594 msg_p != &ss->ssl3.hs.lastMessageFlight;
595 msg_p = PR_NEXT_LINK(msg_p)) {
596 DTLSQueuedMessage *msg = (DTLSQueuedMessage *)msg_p;
597
598 /* The logic here is:
599 *
600 * 1. If this is a message that will not fit into the remaining
601 * space, then flush.
602 * 2. If the message will now fit into the remaining space,
603 * encrypt, buffer, and loop.
604 * 3. If the message will not fit, then fragment.
605 *
606 * At the end of the function, flush.
607 */
608 if ((msg->len + SSL3_BUFFER_FUDGE) > room_left) {
609 /* The message will not fit into the remaining space, so flush */
610 rv = dtls_SendSavedWriteData(ss);
611 if (rv != SECSuccess)
612 break;
613
614 room_left = ss->ssl3.mtu;
615 }
616
617 if ((msg->len + SSL3_BUFFER_FUDGE) <= room_left) {
618 /* The message will fit, so encrypt and then continue with the
619 * next packet */
620 sent = ssl3_SendRecord(ss, msg->epoch, msg->type,
621 msg->data, msg->len,
622 ssl_SEND_FLAG_FORCE_INTO_BUFFER |
623 ssl_SEND_FLAG_USE_EPOCH);
624 if (sent != msg->len) {
625 rv = SECFailure;
626 if (sent != -1) {
627 PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
628 }
629 break;
630 }
631
632 room_left = ss->ssl3.mtu - ss->pendingBuf.len;
633 } else {
634 /* The message will not fit, so fragment.
635 *
636 * XXX OK for now. Arrange to coalesce the last fragment
637 * of this message with the next message if possible.
638 * That would be more efficient.
639 */
640 PRUint32 fragment_offset = 0;
641 unsigned char fragment[DTLS_MAX_MTU]; /* >= than largest
642 * plausible MTU */
643
644 /* Assert that we have already flushed */
645 PORT_Assert(room_left == ss->ssl3.mtu);
646
647 /* Case 3: We now need to fragment this message
648 * DTLS only supports fragmenting handshaking messages */
649 PORT_Assert(msg->type == content_handshake);
650
651 /* The headers consume 12 bytes so the smalles possible
652 * message (i.e., an empty one) is 12 bytes
653 */
654 PORT_Assert(msg->len >= 12);
655
656 while ((fragment_offset + 12) < msg->len) {
657 PRUint32 fragment_len;
658 const unsigned char *content = msg->data + 12;
659 PRUint32 content_len = msg->len - 12;
660
661 /* The reason we use 8 here is that that's the length of
662 * the new DTLS data that we add to the header */
663 fragment_len = PR_MIN(room_left - (SSL3_BUFFER_FUDGE + 8),
664 content_len - fragment_offset);
665 PORT_Assert(fragment_len < DTLS_MAX_MTU - 12);
666 /* Make totally sure that we are within the buffer.
667 * Note that the only way that fragment len could get
668 * adjusted here is if
669 *
670 * (a) we are in release mode so the PORT_Assert is compiled out
671 * (b) either the MTU table is inconsistent with DTLS_MAX_MTU
672 * or ss->ssl3.mtu has become corrupt.
673 */
674 fragment_len = PR_MIN(fragment_len, DTLS_MAX_MTU - 12);
675
676 /* Construct an appropriate-sized fragment */
677 /* Type, length, sequence */
678 PORT_Memcpy(fragment, msg->data, 6);
679
680 /* Offset */
681 fragment[6] = (fragment_offset >> 16) & 0xff;
682 fragment[7] = (fragment_offset >> 8) & 0xff;
683 fragment[8] = (fragment_offset) & 0xff;
684
685 /* Fragment length */
686 fragment[9] = (fragment_len >> 16) & 0xff;
687 fragment[10] = (fragment_len >> 8) & 0xff;
688 fragment[11] = (fragment_len) & 0xff;
689
690 PORT_Memcpy(fragment + 12, content + fragment_offset,
691 fragment_len);
692
693 /*
694 * Send the record. We do this in two stages
695 * 1. Encrypt
696 */
697 sent = ssl3_SendRecord(ss, msg->epoch, msg->type,
698 fragment, fragment_len + 12,
699 ssl_SEND_FLAG_FORCE_INTO_BUFFER |
700 ssl_SEND_FLAG_USE_EPOCH);
701 if (sent != (fragment_len + 12)) {
702 rv = SECFailure;
703 if (sent != -1) {
704 PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
705 }
706 break;
707 }
708
709 /* 2. Flush */
710 rv = dtls_SendSavedWriteData(ss);
711 if (rv != SECSuccess)
712 break;
713
714 fragment_offset += fragment_len;
715 }
716 }
717 }
718
719 /* Finally, we need to flush */
720 if (rv == SECSuccess)
721 rv = dtls_SendSavedWriteData(ss);
722
723 /* Give up the locks */
724 ssl_ReleaseSpecReadLock(ss);
725 ssl_ReleaseXmitBufLock(ss);
726
727 return rv;
728 }
729
730 /* Flush the data in the pendingBuf and update the max message sent
731 * so we can adjust the MTU estimate if we need to.
732 * Wrapper for ssl_SendSavedWriteData.
733 *
734 * Called from dtls_TransmitMessageFlight()
735 */
736 static
dtls_SendSavedWriteData(sslSocket * ss)737 SECStatus dtls_SendSavedWriteData(sslSocket *ss)
738 {
739 PRInt32 sent;
740
741 sent = ssl_SendSavedWriteData(ss);
742 if (sent < 0)
743 return SECFailure;
744
745 /* We should always have complete writes b/c datagram sockets
746 * don't really block */
747 if (ss->pendingBuf.len > 0) {
748 ssl_MapLowLevelError(SSL_ERROR_SOCKET_WRITE_FAILURE);
749 return SECFailure;
750 }
751
752 /* Update the largest message sent so we can adjust the MTU
753 * estimate if necessary */
754 if (sent > ss->ssl3.hs.maxMessageSent)
755 ss->ssl3.hs.maxMessageSent = sent;
756
757 return SECSuccess;
758 }
759
760 /* Compress, MAC, encrypt a DTLS record. Allows specification of
761 * the epoch using epoch value. If use_epoch is PR_TRUE then
762 * we use the provided epoch. If use_epoch is PR_FALSE then
763 * whatever the current value is in effect is used.
764 *
765 * Called from ssl3_SendRecord()
766 */
767 SECStatus
dtls_CompressMACEncryptRecord(sslSocket * ss,DTLSEpoch epoch,PRBool use_epoch,SSL3ContentType type,const SSL3Opaque * pIn,PRUint32 contentLen,sslBuffer * wrBuf)768 dtls_CompressMACEncryptRecord(sslSocket * ss,
769 DTLSEpoch epoch,
770 PRBool use_epoch,
771 SSL3ContentType type,
772 const SSL3Opaque * pIn,
773 PRUint32 contentLen,
774 sslBuffer * wrBuf)
775 {
776 SECStatus rv = SECFailure;
777 ssl3CipherSpec * cwSpec;
778
779 ssl_GetSpecReadLock(ss); /********************************/
780
781 /* The reason for this switch-hitting code is that we might have
782 * a flight of records spanning an epoch boundary, e.g.,
783 *
784 * ClientKeyExchange (epoch = 0)
785 * ChangeCipherSpec (epoch = 0)
786 * Finished (epoch = 1)
787 *
788 * Thus, each record needs a different cipher spec. The information
789 * about which epoch to use is carried with the record.
790 */
791 if (use_epoch) {
792 if (ss->ssl3.cwSpec->epoch == epoch)
793 cwSpec = ss->ssl3.cwSpec;
794 else if (ss->ssl3.pwSpec->epoch == epoch)
795 cwSpec = ss->ssl3.pwSpec;
796 else
797 cwSpec = NULL;
798 } else {
799 cwSpec = ss->ssl3.cwSpec;
800 }
801
802 if (cwSpec) {
803 rv = ssl3_CompressMACEncryptRecord(cwSpec, ss->sec.isServer, PR_TRUE,
804 PR_FALSE, type, pIn, contentLen,
805 wrBuf);
806 } else {
807 PR_NOT_REACHED("Couldn't find a cipher spec matching epoch");
808 PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
809 }
810 ssl_ReleaseSpecReadLock(ss); /************************************/
811
812 return rv;
813 }
814
815 /* Start a timer
816 *
817 * Called from:
818 * dtls_HandleHandshake()
819 * dtls_FlushHAndshake()
820 * dtls_RestartTimer()
821 */
822 SECStatus
dtls_StartTimer(sslSocket * ss,DTLSTimerCb cb)823 dtls_StartTimer(sslSocket *ss, DTLSTimerCb cb)
824 {
825 PORT_Assert(ss->ssl3.hs.rtTimerCb == NULL);
826
827 ss->ssl3.hs.rtTimerStarted = PR_IntervalNow();
828 ss->ssl3.hs.rtTimerCb = cb;
829
830 return SECSuccess;
831 }
832
833 /* Restart a timer with optional backoff
834 *
835 * Called from dtls_RetransmitTimerExpiredCb()
836 */
837 SECStatus
dtls_RestartTimer(sslSocket * ss,PRBool backoff,DTLSTimerCb cb)838 dtls_RestartTimer(sslSocket *ss, PRBool backoff, DTLSTimerCb cb)
839 {
840 if (backoff) {
841 ss->ssl3.hs.rtTimeoutMs *= 2;
842 if (ss->ssl3.hs.rtTimeoutMs > MAX_DTLS_TIMEOUT_MS)
843 ss->ssl3.hs.rtTimeoutMs = MAX_DTLS_TIMEOUT_MS;
844 }
845
846 return dtls_StartTimer(ss, cb);
847 }
848
849 /* Cancel a pending timer
850 *
851 * Called from:
852 * dtls_HandleHandshake()
853 * dtls_CheckTimer()
854 */
855 void
dtls_CancelTimer(sslSocket * ss)856 dtls_CancelTimer(sslSocket *ss)
857 {
858 PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
859
860 ss->ssl3.hs.rtTimerCb = NULL;
861 }
862
863 /* Check the pending timer and fire the callback if it expired
864 *
865 * Called from ssl3_GatherCompleteHandshake()
866 */
867 void
dtls_CheckTimer(sslSocket * ss)868 dtls_CheckTimer(sslSocket *ss)
869 {
870 if (!ss->ssl3.hs.rtTimerCb)
871 return;
872
873 if ((PR_IntervalNow() - ss->ssl3.hs.rtTimerStarted) >
874 PR_MillisecondsToInterval(ss->ssl3.hs.rtTimeoutMs)) {
875 /* Timer has expired */
876 DTLSTimerCb cb = ss->ssl3.hs.rtTimerCb;
877
878 /* Cancel the timer so that we can call the CB safely */
879 dtls_CancelTimer(ss);
880
881 /* Now call the CB */
882 cb(ss);
883 }
884 }
885
886 /* The callback to fire when the holddown timer for the Finished
887 * message expires and we can delete it
888 *
889 * Called from dtls_CheckTimer()
890 */
891 void
dtls_FinishedTimerCb(sslSocket * ss)892 dtls_FinishedTimerCb(sslSocket *ss)
893 {
894 ssl3_DestroyCipherSpec(ss->ssl3.pwSpec, PR_FALSE);
895 }
896
897 /* Cancel the Finished hold-down timer and destroy the
898 * pending cipher spec. Note that this means that
899 * successive rehandshakes will fail if the Finished is
900 * lost.
901 *
902 * XXX OK for now. Figure out how to handle the combination
903 * of Finished lost and rehandshake
904 */
905 void
dtls_RehandshakeCleanup(sslSocket * ss)906 dtls_RehandshakeCleanup(sslSocket *ss)
907 {
908 dtls_CancelTimer(ss);
909 ssl3_DestroyCipherSpec(ss->ssl3.pwSpec, PR_FALSE);
910 ss->ssl3.hs.sendMessageSeq = 0;
911 ss->ssl3.hs.recvMessageSeq = 0;
912 }
913
914 /* Set the MTU to the next step less than or equal to the
915 * advertised value. Also used to downgrade the MTU by
916 * doing dtls_SetMTU(ss, biggest packet set).
917 *
918 * Passing 0 means set this to the largest MTU known
919 * (effectively resetting the PMTU backoff value).
920 *
921 * Called by:
922 * ssl3_InitState()
923 * dtls_RetransmitTimerExpiredCb()
924 */
925 void
dtls_SetMTU(sslSocket * ss,PRUint16 advertised)926 dtls_SetMTU(sslSocket *ss, PRUint16 advertised)
927 {
928 int i;
929
930 if (advertised == 0) {
931 ss->ssl3.mtu = COMMON_MTU_VALUES[0];
932 SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu));
933 return;
934 }
935
936 for (i = 0; i < PR_ARRAY_SIZE(COMMON_MTU_VALUES); i++) {
937 if (COMMON_MTU_VALUES[i] <= advertised) {
938 ss->ssl3.mtu = COMMON_MTU_VALUES[i];
939 SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu));
940 return;
941 }
942 }
943
944 /* Fallback */
945 ss->ssl3.mtu = COMMON_MTU_VALUES[PR_ARRAY_SIZE(COMMON_MTU_VALUES)-1];
946 SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu));
947 }
948
949 /* Called from ssl3_HandleHandshakeMessage() when it has deciphered a
950 * DTLS hello_verify_request
951 * Caller must hold Handshake and RecvBuf locks.
952 */
953 SECStatus
dtls_HandleHelloVerifyRequest(sslSocket * ss,SSL3Opaque * b,PRUint32 length)954 dtls_HandleHelloVerifyRequest(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
955 {
956 int errCode = SSL_ERROR_RX_MALFORMED_HELLO_VERIFY_REQUEST;
957 SECStatus rv;
958 PRInt32 temp;
959 SECItem cookie = {siBuffer, NULL, 0};
960 SSL3AlertDescription desc = illegal_parameter;
961
962 SSL_TRC(3, ("%d: SSL3[%d]: handle hello_verify_request handshake",
963 SSL_GETPID(), ss->fd));
964 PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
965 PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
966
967 if (ss->ssl3.hs.ws != wait_server_hello) {
968 errCode = SSL_ERROR_RX_UNEXPECTED_HELLO_VERIFY_REQUEST;
969 desc = unexpected_message;
970 goto alert_loser;
971 }
972
973 /* The version */
974 temp = ssl3_ConsumeHandshakeNumber(ss, 2, &b, &length);
975 if (temp < 0) {
976 goto loser; /* alert has been sent */
977 }
978
979 if (temp != SSL_LIBRARY_VERSION_DTLS_1_0_WIRE) {
980 /* Note: this will need adjustment for DTLS 1.2 per Section 4.2.1 */
981 goto alert_loser;
982 }
983
984 /* The cookie */
985 rv = ssl3_ConsumeHandshakeVariable(ss, &cookie, 1, &b, &length);
986 if (rv != SECSuccess) {
987 goto loser; /* alert has been sent */
988 }
989 if (cookie.len > DTLS_COOKIE_BYTES) {
990 desc = decode_error;
991 goto alert_loser; /* malformed. */
992 }
993
994 PORT_Memcpy(ss->ssl3.hs.cookie, cookie.data, cookie.len);
995 ss->ssl3.hs.cookieLen = cookie.len;
996
997
998 ssl_GetXmitBufLock(ss); /*******************************/
999
1000 /* Now re-send the client hello */
1001 rv = ssl3_SendClientHello(ss, PR_TRUE);
1002
1003 ssl_ReleaseXmitBufLock(ss); /*******************************/
1004
1005 if (rv == SECSuccess)
1006 return rv;
1007
1008 alert_loser:
1009 (void)SSL3_SendAlert(ss, alert_fatal, desc);
1010
1011 loser:
1012 errCode = ssl_MapLowLevelError(errCode);
1013 return SECFailure;
1014 }
1015
1016 /* Initialize the DTLS anti-replay window
1017 *
1018 * Called from:
1019 * ssl3_SetupPendingCipherSpec()
1020 * ssl3_InitCipherSpec()
1021 */
1022 void
dtls_InitRecvdRecords(DTLSRecvdRecords * records)1023 dtls_InitRecvdRecords(DTLSRecvdRecords *records)
1024 {
1025 PORT_Memset(records->data, 0, sizeof(records->data));
1026 records->left = 0;
1027 records->right = DTLS_RECVD_RECORDS_WINDOW - 1;
1028 }
1029
1030 /*
1031 * Has this DTLS record been received? Return values are:
1032 * -1 -- out of range to the left
1033 * 0 -- not received yet
1034 * 1 -- replay
1035 *
1036 * Called from: dtls_HandleRecord()
1037 */
1038 int
dtls_RecordGetRecvd(DTLSRecvdRecords * records,PRUint64 seq)1039 dtls_RecordGetRecvd(DTLSRecvdRecords *records, PRUint64 seq)
1040 {
1041 PRUint64 offset;
1042
1043 /* Out of range to the left */
1044 if (seq < records->left) {
1045 return -1;
1046 }
1047
1048 /* Out of range to the right; since we advance the window on
1049 * receipt, that means that this packet has not been received
1050 * yet */
1051 if (seq > records->right)
1052 return 0;
1053
1054 offset = seq % DTLS_RECVD_RECORDS_WINDOW;
1055
1056 return !!(records->data[offset / 8] & (1 << (offset % 8)));
1057 }
1058
1059 /* Update the DTLS anti-replay window
1060 *
1061 * Called from ssl3_HandleRecord()
1062 */
1063 void
dtls_RecordSetRecvd(DTLSRecvdRecords * records,PRUint64 seq)1064 dtls_RecordSetRecvd(DTLSRecvdRecords *records, PRUint64 seq)
1065 {
1066 PRUint64 offset;
1067
1068 if (seq < records->left)
1069 return;
1070
1071 if (seq > records->right) {
1072 PRUint64 new_left;
1073 PRUint64 new_right;
1074 PRUint64 right;
1075
1076 /* Slide to the right; this is the tricky part
1077 *
1078 * 1. new_top is set to have room for seq, on the
1079 * next byte boundary by setting the right 8
1080 * bits of seq
1081 * 2. new_left is set to compensate.
1082 * 3. Zero all bits between top and new_top. Since
1083 * this is a ring, this zeroes everything as-yet
1084 * unseen. Because we always operate on byte
1085 * boundaries, we can zero one byte at a time
1086 */
1087 new_right = seq | 0x07;
1088 new_left = (new_right - DTLS_RECVD_RECORDS_WINDOW) + 1;
1089
1090 for (right = records->right + 8; right <= new_right; right += 8) {
1091 offset = right % DTLS_RECVD_RECORDS_WINDOW;
1092 records->data[offset / 8] = 0;
1093 }
1094
1095 records->right = new_right;
1096 records->left = new_left;
1097 }
1098
1099 offset = seq % DTLS_RECVD_RECORDS_WINDOW;
1100
1101 records->data[offset / 8] |= (1 << (offset % 8));
1102 }
1103
1104 SECStatus
DTLS_GetHandshakeTimeout(PRFileDesc * socket,PRIntervalTime * timeout)1105 DTLS_GetHandshakeTimeout(PRFileDesc *socket, PRIntervalTime *timeout)
1106 {
1107 sslSocket * ss = NULL;
1108 PRIntervalTime elapsed;
1109 PRIntervalTime desired;
1110
1111 ss = ssl_FindSocket(socket);
1112
1113 if (!ss)
1114 return SECFailure;
1115
1116 if (!IS_DTLS(ss))
1117 return SECFailure;
1118
1119 if (!ss->ssl3.hs.rtTimerCb)
1120 return SECFailure;
1121
1122 elapsed = PR_IntervalNow() - ss->ssl3.hs.rtTimerStarted;
1123 desired = PR_MillisecondsToInterval(ss->ssl3.hs.rtTimeoutMs);
1124 if (elapsed > desired) {
1125 /* Timer expired */
1126 *timeout = PR_INTERVAL_NO_WAIT;
1127 } else {
1128 *timeout = desired - elapsed;
1129 }
1130
1131 return SECSuccess;
1132 }
1133