• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* This Source Code Form is subject to the terms of the Mozilla Public
2  * License, v. 2.0. If a copy of the MPL was not distributed with this
3  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4 
5 /*
6  * DTLS Protocol
7  */
8 
9 #include "ssl.h"
10 #include "sslimpl.h"
11 #include "sslproto.h"
12 
/* Fallback for NSPR versions that lack PR_ARRAY_SIZE: element count of a
 * true array object (it gives a wrong answer if passed a pointer). */
#if !defined(PR_ARRAY_SIZE)
#define PR_ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))
#endif
16 
17 static SECStatus dtls_TransmitMessageFlight(sslSocket *ss);
18 static void dtls_RetransmitTimerExpiredCb(sslSocket *ss);
19 static SECStatus dtls_SendSavedWriteData(sslSocket *ss);
20 
21 /* -28 adjusts for the IP/UDP header */
22 static const PRUint16 COMMON_MTU_VALUES[] = {
23     1500 - 28,  /* Ethernet MTU */
24     1280 - 28,  /* IPv6 minimum MTU */
25     576 - 28,   /* Common assumption */
26     256 - 28    /* We're in serious trouble now */
27 };
28 
29 #define DTLS_COOKIE_BYTES 32
30 
31 /* List copied from ssl3con.c:cipherSuites */
32 static const ssl3CipherSuite nonDTLSSuites[] = {
33 #ifdef NSS_ENABLE_ECC
34     TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
35     TLS_ECDHE_RSA_WITH_RC4_128_SHA,
36 #endif  /* NSS_ENABLE_ECC */
37     TLS_DHE_DSS_WITH_RC4_128_SHA,
38 #ifdef NSS_ENABLE_ECC
39     TLS_ECDH_RSA_WITH_RC4_128_SHA,
40     TLS_ECDH_ECDSA_WITH_RC4_128_SHA,
41 #endif  /* NSS_ENABLE_ECC */
42     SSL_RSA_WITH_RC4_128_MD5,
43     SSL_RSA_WITH_RC4_128_SHA,
44     TLS_RSA_EXPORT1024_WITH_RC4_56_SHA,
45     SSL_RSA_EXPORT_WITH_RC4_40_MD5,
46     0 /* End of list marker */
47 };
48 
49 /* Map back and forth between TLS and DTLS versions in wire format.
50  * Mapping table is:
51  *
52  * TLS             DTLS
53  * 1.1 (0302)      1.0 (feff)
54  */
55 SSL3ProtocolVersion
dtls_TLSVersionToDTLSVersion(SSL3ProtocolVersion tlsv)56 dtls_TLSVersionToDTLSVersion(SSL3ProtocolVersion tlsv)
57 {
58     /* Anything other than TLS 1.1 is an error, so return
59      * the invalid version ffff. */
60     if (tlsv != SSL_LIBRARY_VERSION_TLS_1_1)
61 	return 0xffff;
62 
63     return SSL_LIBRARY_VERSION_DTLS_1_0_WIRE;
64 }
65 
66 /* Map known DTLS versions to known TLS versions.
67  * - Invalid versions (< 1.0) return a version of 0
68  * - Versions > known return a version one higher than we know of
69  * to accomodate a theoretically newer version */
70 SSL3ProtocolVersion
dtls_DTLSVersionToTLSVersion(SSL3ProtocolVersion dtlsv)71 dtls_DTLSVersionToTLSVersion(SSL3ProtocolVersion dtlsv)
72 {
73     if (MSB(dtlsv) == 0xff) {
74 	return 0;
75     }
76 
77     if (dtlsv == SSL_LIBRARY_VERSION_DTLS_1_0_WIRE)
78 	return SSL_LIBRARY_VERSION_TLS_1_1;
79 
80     /* Return a fictional higher version than we know of */
81     return SSL_LIBRARY_VERSION_TLS_1_1 + 1;
82 }
83 
84 /* On this socket, Disable non-DTLS cipher suites in the argument's list */
85 SECStatus
ssl3_DisableNonDTLSSuites(sslSocket * ss)86 ssl3_DisableNonDTLSSuites(sslSocket * ss)
87 {
88     const ssl3CipherSuite * suite;
89 
90     for (suite = nonDTLSSuites; *suite; ++suite) {
91 	SECStatus rv = ssl3_CipherPrefSet(ss, *suite, PR_FALSE);
92 
93 	PORT_Assert(rv == SECSuccess); /* else is coding error */
94     }
95     return SECSuccess;
96 }
97 
98 /* Allocate a DTLSQueuedMessage.
99  *
100  * Called from dtls_QueueMessage()
101  */
102 static DTLSQueuedMessage *
dtls_AllocQueuedMessage(PRUint16 epoch,SSL3ContentType type,const unsigned char * data,PRUint32 len)103 dtls_AllocQueuedMessage(PRUint16 epoch, SSL3ContentType type,
104 			const unsigned char *data, PRUint32 len)
105 {
106     DTLSQueuedMessage *msg = NULL;
107 
108     msg = PORT_ZAlloc(sizeof(DTLSQueuedMessage));
109     if (!msg)
110 	return NULL;
111 
112     msg->data = PORT_Alloc(len);
113     if (!msg->data) {
114 	PORT_Free(msg);
115         return NULL;
116     }
117     PORT_Memcpy(msg->data, data, len);
118 
119     msg->len = len;
120     msg->epoch = epoch;
121     msg->type = type;
122 
123     return msg;
124 }
125 
126 /*
127  * Free a handshake message
128  *
129  * Called from dtls_FreeHandshakeMessages()
130  */
131 static void
dtls_FreeHandshakeMessage(DTLSQueuedMessage * msg)132 dtls_FreeHandshakeMessage(DTLSQueuedMessage *msg)
133 {
134     if (!msg)
135 	return;
136 
137     PORT_ZFree(msg->data, msg->len);
138     PORT_Free(msg);
139 }
140 
141 /*
142  * Free a list of handshake messages
143  *
144  * Called from:
145  *              dtls_HandleHandshake()
146  *              ssl3_DestroySSL3Info()
147  */
148 void
dtls_FreeHandshakeMessages(PRCList * list)149 dtls_FreeHandshakeMessages(PRCList *list)
150 {
151     PRCList *cur_p;
152 
153     while (!PR_CLIST_IS_EMPTY(list)) {
154 	cur_p = PR_LIST_TAIL(list);
155 	PR_REMOVE_LINK(cur_p);
156 	dtls_FreeHandshakeMessage((DTLSQueuedMessage *)cur_p);
157     }
158 }
159 
/* Called only from ssl3_HandleRecord, for each (deciphered) DTLS record.
 * origBuf is the decrypted ssl record content and is expected to contain
 * complete handshake records.
 * Caller must hold the handshake and RecvBuf locks.
 *
 * Note that this code uses msg_len for two purposes:
 *
 * (1) To pass the length to ssl3_HandleHandshakeMessage()
 * (2) To carry the length of a message currently being reassembled
 *
 * However, unlike ssl3_HandleHandshake(), it is not used to carry
 * the state of reassembly (i.e., whether one is in progress). That
 * is carried in recvdHighWater and recvdFragments.
 */

/* Bitmap helpers for recvdFragments, which keeps one bit per byte of the
 * message being reassembled:
 *   OFFSET_BYTE(o) - index of the map byte holding the bit for offset o
 *   OFFSET_MASK(o) - mask selecting that bit within the map byte
 * The argument is fully parenthesized so that expressions such as
 * OFFSET_BYTE(a + b) expand as intended. */
#define OFFSET_BYTE(o) ((o) / 8)
#define OFFSET_MASK(o) (1 << ((o) % 8))
176 
177 SECStatus
dtls_HandleHandshake(sslSocket * ss,sslBuffer * origBuf)178 dtls_HandleHandshake(sslSocket *ss, sslBuffer *origBuf)
179 {
180     /* XXX OK for now.
181      * This doesn't work properly with asynchronous certificate validation.
182      * because that returns a WOULDBLOCK error. The current DTLS
183      * applications do not need asynchronous validation, but in the
184      * future we will need to add this.
185      */
186     sslBuffer buf = *origBuf;
187     SECStatus rv = SECSuccess;
188 
189     PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
190     PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
191 
192     while (buf.len > 0) {
193         PRUint8 type;
194         PRUint32 message_length;
195         PRUint16 message_seq;
196         PRUint32 fragment_offset;
197         PRUint32 fragment_length;
198         PRUint32 offset;
199 
200         if (buf.len < 12) {
201             PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE);
202             rv = SECFailure;
203             break;
204         }
205 
206         /* Parse the header */
207 	type = buf.buf[0];
208         message_length = (buf.buf[1] << 16) | (buf.buf[2] << 8) | buf.buf[3];
209         message_seq = (buf.buf[4] << 8) | buf.buf[5];
210         fragment_offset = (buf.buf[6] << 16) | (buf.buf[7] << 8) | buf.buf[8];
211         fragment_length = (buf.buf[9] << 16) | (buf.buf[10] << 8) | buf.buf[11];
212 
213 #define MAX_HANDSHAKE_MSG_LEN 0x1ffff	/* 128k - 1 */
214 	if (message_length > MAX_HANDSHAKE_MSG_LEN) {
215 	    (void)ssl3_DecodeError(ss);
216 	    PORT_SetError(SSL_ERROR_RX_RECORD_TOO_LONG);
217 	    return SECFailure;
218 	}
219 #undef MAX_HANDSHAKE_MSG_LEN
220 
221         buf.buf += 12;
222         buf.len -= 12;
223 
224         /* This fragment must be complete */
225         if (buf.len < fragment_length) {
226             PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE);
227             rv = SECFailure;
228             break;
229         }
230 
231         /* Sanity check the packet contents */
232 	if ((fragment_length + fragment_offset) > message_length) {
233             PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE);
234             rv = SECFailure;
235             break;
236         }
237 
238         /* There are three ways we could not be ready for this packet.
239          *
240          * 1. It's a partial next message.
241          * 2. It's a partial or complete message beyond the next
242          * 3. It's a message we've already seen
243          *
244          * If it's the complete next message we accept it right away.
245          * This is the common case for short messages
246          */
247         if ((message_seq == ss->ssl3.hs.recvMessageSeq)
248 	    && (fragment_offset == 0)
249 	    && (fragment_length == message_length)) {
250             /* Complete next message. Process immediately */
251             ss->ssl3.hs.msg_type = (SSL3HandshakeType)type;
252             ss->ssl3.hs.msg_len = message_length;
253 
254             /* At this point we are advancing our state machine, so
255              * we can free our last flight of messages */
256             dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight);
257 	    ss->ssl3.hs.recvdHighWater = -1;
258 	    dtls_CancelTimer(ss);
259 
260 	    /* Reset the timer to the initial value if the retry counter
261 	     * is 0, per Sec. 4.2.4.1 */
262 	    if (ss->ssl3.hs.rtRetries == 0) {
263 		ss->ssl3.hs.rtTimeoutMs = INITIAL_DTLS_TIMEOUT_MS;
264 	    }
265 
266             rv = ssl3_HandleHandshakeMessage(ss, buf.buf, ss->ssl3.hs.msg_len);
267             if (rv == SECFailure) {
268                 /* Do not attempt to process rest of messages in this record */
269                 break;
270             }
271         } else {
272 	    if (message_seq < ss->ssl3.hs.recvMessageSeq) {
273 		/* Case 3: we do an immediate retransmit if we're
274 		 * in a waiting state*/
275 		if (ss->ssl3.hs.rtTimerCb == NULL) {
276 		    /* Ignore */
277 		} else if (ss->ssl3.hs.rtTimerCb ==
278 			 dtls_RetransmitTimerExpiredCb) {
279 		    SSL_TRC(30, ("%d: SSL3[%d]: Retransmit detected",
280 				 SSL_GETPID(), ss->fd));
281 		    /* Check to see if we retransmitted recently. If so,
282 		     * suppress the triggered retransmit. This avoids
283 		     * retransmit wars after packet loss.
284 		     * This is not in RFC 5346 but should be
285 		     */
286 		    if ((PR_IntervalNow() - ss->ssl3.hs.rtTimerStarted) >
287 			(ss->ssl3.hs.rtTimeoutMs / 4)) {
288 			    SSL_TRC(30,
289 			    ("%d: SSL3[%d]: Shortcutting retransmit timer",
290                             SSL_GETPID(), ss->fd));
291 
292 			    /* Cancel the timer and call the CB,
293 			     * which re-arms the timer */
294 			    dtls_CancelTimer(ss);
295 			    dtls_RetransmitTimerExpiredCb(ss);
296 			    rv = SECSuccess;
297 			    break;
298 			} else {
299 			    SSL_TRC(30,
300 			    ("%d: SSL3[%d]: We just retransmitted. Ignoring.",
301                             SSL_GETPID(), ss->fd));
302 			    rv = SECSuccess;
303 			    break;
304 			}
305 		} else if (ss->ssl3.hs.rtTimerCb == dtls_FinishedTimerCb) {
306 		    /* Retransmit the messages and re-arm the timer
307 		     * Note that we are not backing off the timer here.
308 		     * The spec isn't clear and my reasoning is that this
309 		     * may be a re-ordered packet rather than slowness,
310 		     * so let's be aggressive. */
311 		    dtls_CancelTimer(ss);
312 		    rv = dtls_TransmitMessageFlight(ss);
313 		    if (rv == SECSuccess) {
314 			rv = dtls_StartTimer(ss, dtls_FinishedTimerCb);
315 		    }
316 		    if (rv != SECSuccess)
317 			return rv;
318 		    break;
319 		}
320 	    } else if (message_seq > ss->ssl3.hs.recvMessageSeq) {
321 		/* Case 2
322                  *
323 		 * Ignore this message. This means we don't handle out of
324 		 * order complete messages that well, but we're still
325 		 * compliant and this probably does not happen often
326                  *
327 		 * XXX OK for now. Maybe do something smarter at some point?
328 		 */
329 	    } else {
330 		/* Case 1
331                  *
332 		 * Buffer the fragment for reassembly
333 		 */
334                 /* Make room for the message */
335                 if (ss->ssl3.hs.recvdHighWater == -1) {
336                     PRUint32 map_length = OFFSET_BYTE(message_length) + 1;
337 
338                     rv = sslBuffer_Grow(&ss->ssl3.hs.msg_body, message_length);
339                     if (rv != SECSuccess)
340                         break;
341                     /* Make room for the fragment map */
342                     rv = sslBuffer_Grow(&ss->ssl3.hs.recvdFragments,
343                                         map_length);
344                     if (rv != SECSuccess)
345                         break;
346 
347                     /* Reset the reassembly map */
348                     ss->ssl3.hs.recvdHighWater = 0;
349                     PORT_Memset(ss->ssl3.hs.recvdFragments.buf, 0,
350 				ss->ssl3.hs.recvdFragments.space);
351 		    ss->ssl3.hs.msg_type = (SSL3HandshakeType)type;
352                     ss->ssl3.hs.msg_len = message_length;
353                 }
354 
355                 /* If we have a message length mismatch, abandon the reassembly
356                  * in progress and hope that the next retransmit will give us
357                  * something sane
358                  */
359                 if (message_length != ss->ssl3.hs.msg_len) {
360                     ss->ssl3.hs.recvdHighWater = -1;
361                     PORT_SetError(SSL_ERROR_RX_MALFORMED_HANDSHAKE);
362                     rv = SECFailure;
363                     break;
364                 }
365 
366                 /* Now copy this fragment into the buffer */
367                 PORT_Assert((fragment_offset + fragment_length) <=
368                             ss->ssl3.hs.msg_body.space);
369                 PORT_Memcpy(ss->ssl3.hs.msg_body.buf + fragment_offset,
370                             buf.buf, fragment_length);
371 
372                 /* This logic is a bit tricky. We have two values for
373                  * reassembly state:
374                  *
375                  * - recvdHighWater contains the highest contiguous number of
376                  *   bytes received
377                  * - recvdFragments contains a bitmask of packets received
378                  *   above recvdHighWater
379                  *
380                  * This avoids having to fill in the bitmask in the common
381                  * case of adjacent fragments received in sequence
382                  */
383                 if (fragment_offset <= ss->ssl3.hs.recvdHighWater) {
384 		    /* Either this is the adjacent fragment or an overlapping
385                      * fragment */
386                     ss->ssl3.hs.recvdHighWater = fragment_offset +
387                                                  fragment_length;
388                 } else {
389                     for (offset = fragment_offset;
390                          offset < fragment_offset + fragment_length;
391                          offset++) {
392                         ss->ssl3.hs.recvdFragments.buf[OFFSET_BYTE(offset)] |=
393                             OFFSET_MASK(offset);
394                     }
395                 }
396 
397                 /* Now figure out the new high water mark if appropriate */
398                 for (offset = ss->ssl3.hs.recvdHighWater;
399                      offset < ss->ssl3.hs.msg_len; offset++) {
400 		    /* Note that this loop is not efficient, since it counts
401 		     * bit by bit. If we have a lot of out-of-order packets,
402 		     * we should optimize this */
403                     if (ss->ssl3.hs.recvdFragments.buf[OFFSET_BYTE(offset)] &
404                         OFFSET_MASK(offset)) {
405                         ss->ssl3.hs.recvdHighWater++;
406                     } else {
407                         break;
408                     }
409                 }
410 
411                 /* If we have all the bytes, then we are good to go */
412                 if (ss->ssl3.hs.recvdHighWater == ss->ssl3.hs.msg_len) {
413                     ss->ssl3.hs.recvdHighWater = -1;
414 
415                     rv = ssl3_HandleHandshakeMessage(ss,
416                                                      ss->ssl3.hs.msg_body.buf,
417                                                      ss->ssl3.hs.msg_len);
418                     if (rv == SECFailure)
419                         break; /* Skip rest of record */
420 
421 		    /* At this point we are advancing our state machine, so
422 		     * we can free our last flight of messages */
423 		    dtls_FreeHandshakeMessages(&ss->ssl3.hs.lastMessageFlight);
424 		    dtls_CancelTimer(ss);
425 
426 		    /* If there have been no retries this time, reset the
427 		     * timer value to the default per Section 4.2.4.1 */
428 		    if (ss->ssl3.hs.rtRetries == 0) {
429 			ss->ssl3.hs.rtTimeoutMs = INITIAL_DTLS_TIMEOUT_MS;
430 		    }
431                 }
432             }
433         }
434 
435 	buf.buf += fragment_length;
436         buf.len -= fragment_length;
437     }
438 
439     origBuf->len = 0;	/* So ssl3_GatherAppDataRecord will keep looping. */
440 
441     /* XXX OK for now. In future handle rv == SECWouldBlock safely in order
442      * to deal with asynchronous certificate verification */
443     return rv;
444 }
445 
446 /* Enqueue a message (either handshake or CCS)
447  *
448  * Called from:
449  *              dtls_StageHandshakeMessage()
450  *              ssl3_SendChangeCipherSpecs()
451  */
dtls_QueueMessage(sslSocket * ss,SSL3ContentType type,const SSL3Opaque * pIn,PRInt32 nIn)452 SECStatus dtls_QueueMessage(sslSocket *ss, SSL3ContentType type,
453     const SSL3Opaque *pIn, PRInt32 nIn)
454 {
455     SECStatus rv = SECSuccess;
456     DTLSQueuedMessage *msg = NULL;
457 
458     PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
459     PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss));
460 
461     msg = dtls_AllocQueuedMessage(ss->ssl3.cwSpec->epoch, type, pIn, nIn);
462 
463     if (!msg) {
464 	PORT_SetError(SEC_ERROR_NO_MEMORY);
465 	rv = SECFailure;
466     } else {
467 	PR_APPEND_LINK(&msg->link, &ss->ssl3.hs.lastMessageFlight);
468     }
469 
470     return rv;
471 }
472 
473 /* Add DTLS handshake message to the pending queue
474  * Empty the sendBuf buffer.
475  * This function returns SECSuccess or SECFailure, never SECWouldBlock.
476  * Always set sendBuf.len to 0, even when returning SECFailure.
477  *
478  * Called from:
479  *              ssl3_AppendHandshakeHeader()
480  *              dtls_FlushHandshake()
481  */
482 SECStatus
dtls_StageHandshakeMessage(sslSocket * ss)483 dtls_StageHandshakeMessage(sslSocket *ss)
484 {
485     SECStatus rv = SECSuccess;
486 
487     PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
488     PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss));
489 
490     /* This function is sometimes called when no data is actually to
491      * be staged, so just return SECSuccess. */
492     if (!ss->sec.ci.sendBuf.buf || !ss->sec.ci.sendBuf.len)
493 	return rv;
494 
495     rv = dtls_QueueMessage(ss, content_handshake,
496                            ss->sec.ci.sendBuf.buf, ss->sec.ci.sendBuf.len);
497 
498     /* Whether we succeeded or failed, toss the old handshake data. */
499     ss->sec.ci.sendBuf.len = 0;
500     return rv;
501 }
502 
503 /* Enqueue the handshake message in sendBuf (if any) and then
504  * transmit the resulting flight of handshake messages.
505  *
506  * Called from:
507  *              ssl3_FlushHandshake()
508  */
509 SECStatus
dtls_FlushHandshakeMessages(sslSocket * ss,PRInt32 flags)510 dtls_FlushHandshakeMessages(sslSocket *ss, PRInt32 flags)
511 {
512     SECStatus rv = SECSuccess;
513 
514     PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
515     PORT_Assert(ss->opt.noLocks || ssl_HaveXmitBufLock(ss));
516 
517     rv = dtls_StageHandshakeMessage(ss);
518     if (rv != SECSuccess)
519         return rv;
520 
521     if (!(flags & ssl_SEND_FLAG_FORCE_INTO_BUFFER)) {
522         rv = dtls_TransmitMessageFlight(ss);
523         if (rv != SECSuccess)
524             return rv;
525 
526 	if (!(flags & ssl_SEND_FLAG_NO_RETRANSMIT)) {
527 	    ss->ssl3.hs.rtRetries = 0;
528 	    rv = dtls_StartTimer(ss, dtls_RetransmitTimerExpiredCb);
529 	}
530     }
531 
532     return rv;
533 }
534 
535 /* The callback for when the retransmit timer expires
536  *
537  * Called from:
538  *              dtls_CheckTimer()
539  *              dtls_HandleHandshake()
540  */
541 static void
dtls_RetransmitTimerExpiredCb(sslSocket * ss)542 dtls_RetransmitTimerExpiredCb(sslSocket *ss)
543 {
544     SECStatus rv = SECFailure;
545 
546     ss->ssl3.hs.rtRetries++;
547 
548     if (!(ss->ssl3.hs.rtRetries % 3)) {
549 	/* If one of the messages was potentially greater than > MTU,
550 	 * then downgrade. Do this every time we have retransmitted a
551 	 * message twice, per RFC 6347 Sec. 4.1.1 */
552 	dtls_SetMTU(ss, ss->ssl3.hs.maxMessageSent - 1);
553     }
554 
555     rv = dtls_TransmitMessageFlight(ss);
556     if (rv == SECSuccess) {
557 
558 	/* Re-arm the timer */
559 	rv = dtls_RestartTimer(ss, PR_TRUE, dtls_RetransmitTimerExpiredCb);
560     }
561 
562     if (rv == SECFailure) {
563 	/* XXX OK for now. In future maybe signal the stack that we couldn't
564 	 * transmit. For now, let the read handle any real network errors */
565     }
566 }
567 
568 /* Transmit a flight of handshake messages, stuffing them
569  * into as few records as seems reasonable
570  *
571  * Called from:
572  *             dtls_FlushHandshake()
573  *             dtls_RetransmitTimerExpiredCb()
574  */
575 static SECStatus
dtls_TransmitMessageFlight(sslSocket * ss)576 dtls_TransmitMessageFlight(sslSocket *ss)
577 {
578     SECStatus rv = SECSuccess;
579     PRCList *msg_p;
580     PRUint16 room_left = ss->ssl3.mtu;
581     PRInt32 sent;
582 
583     ssl_GetXmitBufLock(ss);
584     ssl_GetSpecReadLock(ss);
585 
586     /* DTLS does not buffer its handshake messages in
587      * ss->pendingBuf, but rather in the lastMessageFlight
588      * structure. This is just a sanity check that
589      * some programming error hasn't inadvertantly
590      * stuffed something in ss->pendingBuf
591      */
592     PORT_Assert(!ss->pendingBuf.len);
593     for (msg_p = PR_LIST_HEAD(&ss->ssl3.hs.lastMessageFlight);
594 	 msg_p != &ss->ssl3.hs.lastMessageFlight;
595 	 msg_p = PR_NEXT_LINK(msg_p)) {
596         DTLSQueuedMessage *msg = (DTLSQueuedMessage *)msg_p;
597 
598         /* The logic here is:
599          *
600 	 * 1. If this is a message that will not fit into the remaining
601 	 *    space, then flush.
602 	 * 2. If the message will now fit into the remaining space,
603          *    encrypt, buffer, and loop.
604          * 3. If the message will not fit, then fragment.
605          *
606 	 * At the end of the function, flush.
607          */
608         if ((msg->len + SSL3_BUFFER_FUDGE) > room_left) {
609 	    /* The message will not fit into the remaining space, so flush */
610 	    rv = dtls_SendSavedWriteData(ss);
611 	    if (rv != SECSuccess)
612 		break;
613 
614             room_left = ss->ssl3.mtu;
615 	}
616 
617         if ((msg->len + SSL3_BUFFER_FUDGE) <= room_left) {
618             /* The message will fit, so encrypt and then continue with the
619 	     * next packet */
620             sent = ssl3_SendRecord(ss, msg->epoch, msg->type,
621 				   msg->data, msg->len,
622 				   ssl_SEND_FLAG_FORCE_INTO_BUFFER |
623 				   ssl_SEND_FLAG_USE_EPOCH);
624             if (sent != msg->len) {
625 		rv = SECFailure;
626 		if (sent != -1) {
627 		    PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
628 		}
629                 break;
630 	    }
631 
632             room_left = ss->ssl3.mtu - ss->pendingBuf.len;
633         } else {
634             /* The message will not fit, so fragment.
635              *
636 	     * XXX OK for now. Arrange to coalesce the last fragment
637 	     * of this message with the next message if possible.
638 	     * That would be more efficient.
639 	     */
640             PRUint32 fragment_offset = 0;
641             unsigned char fragment[DTLS_MAX_MTU]; /* >= than largest
642                                                    * plausible MTU */
643 
644 	    /* Assert that we have already flushed */
645 	    PORT_Assert(room_left == ss->ssl3.mtu);
646 
647             /* Case 3: We now need to fragment this message
648              * DTLS only supports fragmenting handshaking messages */
649             PORT_Assert(msg->type == content_handshake);
650 
651 	    /* The headers consume 12 bytes so the smalles possible
652 	     *  message (i.e., an empty one) is 12 bytes
653 	     */
654 	    PORT_Assert(msg->len >= 12);
655 
656             while ((fragment_offset + 12) < msg->len) {
657                 PRUint32 fragment_len;
658                 const unsigned char *content = msg->data + 12;
659                 PRUint32 content_len = msg->len - 12;
660 
661 		/* The reason we use 8 here is that that's the length of
662 		 * the new DTLS data that we add to the header */
663                 fragment_len = PR_MIN(room_left - (SSL3_BUFFER_FUDGE + 8),
664                                       content_len - fragment_offset);
665 		PORT_Assert(fragment_len < DTLS_MAX_MTU - 12);
666 		/* Make totally sure that we are within the buffer.
667 		 * Note that the only way that fragment len could get
668 		 * adjusted here is if
669                  *
670 		 * (a) we are in release mode so the PORT_Assert is compiled out
671 		 * (b) either the MTU table is inconsistent with DTLS_MAX_MTU
672 		 * or ss->ssl3.mtu has become corrupt.
673 		 */
674 		fragment_len = PR_MIN(fragment_len, DTLS_MAX_MTU - 12);
675 
676                 /* Construct an appropriate-sized fragment */
677                 /* Type, length, sequence */
678                 PORT_Memcpy(fragment, msg->data, 6);
679 
680                 /* Offset */
681                 fragment[6] = (fragment_offset >> 16) & 0xff;
682                 fragment[7] = (fragment_offset >> 8) & 0xff;
683                 fragment[8] = (fragment_offset) & 0xff;
684 
685                 /* Fragment length */
686                 fragment[9] = (fragment_len >> 16) & 0xff;
687                 fragment[10] = (fragment_len >> 8) & 0xff;
688                 fragment[11] = (fragment_len) & 0xff;
689 
690                 PORT_Memcpy(fragment + 12, content + fragment_offset,
691                             fragment_len);
692 
693                 /*
694 		 *  Send the record. We do this in two stages
695 		 * 1. Encrypt
696 		 */
697                 sent = ssl3_SendRecord(ss, msg->epoch, msg->type,
698                                        fragment, fragment_len + 12,
699                                        ssl_SEND_FLAG_FORCE_INTO_BUFFER |
700 				       ssl_SEND_FLAG_USE_EPOCH);
701                 if (sent != (fragment_len + 12)) {
702 		    rv = SECFailure;
703 		    if (sent != -1) {
704 			PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
705 		    }
706 		    break;
707 		}
708 
709 		/* 2. Flush */
710 		rv = dtls_SendSavedWriteData(ss);
711 		if (rv != SECSuccess)
712 		    break;
713 
714                 fragment_offset += fragment_len;
715             }
716         }
717     }
718 
719     /* Finally, we need to flush */
720     if (rv == SECSuccess)
721 	rv = dtls_SendSavedWriteData(ss);
722 
723     /* Give up the locks */
724     ssl_ReleaseSpecReadLock(ss);
725     ssl_ReleaseXmitBufLock(ss);
726 
727     return rv;
728 }
729 
730 /* Flush the data in the pendingBuf and update the max message sent
731  * so we can adjust the MTU estimate if we need to.
732  * Wrapper for ssl_SendSavedWriteData.
733  *
734  * Called from dtls_TransmitMessageFlight()
735  */
736 static
dtls_SendSavedWriteData(sslSocket * ss)737 SECStatus dtls_SendSavedWriteData(sslSocket *ss)
738 {
739     PRInt32 sent;
740 
741     sent = ssl_SendSavedWriteData(ss);
742     if (sent < 0)
743 	return SECFailure;
744 
745     /* We should always have complete writes b/c datagram sockets
746      * don't really block */
747     if (ss->pendingBuf.len > 0) {
748 	ssl_MapLowLevelError(SSL_ERROR_SOCKET_WRITE_FAILURE);
749     	return SECFailure;
750     }
751 
752     /* Update the largest message sent so we can adjust the MTU
753      * estimate if necessary */
754     if (sent > ss->ssl3.hs.maxMessageSent)
755 	ss->ssl3.hs.maxMessageSent = sent;
756 
757     return SECSuccess;
758 }
759 
760 /* Compress, MAC, encrypt a DTLS record. Allows specification of
761  * the epoch using epoch value. If use_epoch is PR_TRUE then
762  * we use the provided epoch. If use_epoch is PR_FALSE then
763  * whatever the current value is in effect is used.
764  *
765  * Called from ssl3_SendRecord()
766  */
767 SECStatus
dtls_CompressMACEncryptRecord(sslSocket * ss,DTLSEpoch epoch,PRBool use_epoch,SSL3ContentType type,const SSL3Opaque * pIn,PRUint32 contentLen,sslBuffer * wrBuf)768 dtls_CompressMACEncryptRecord(sslSocket *        ss,
769                               DTLSEpoch          epoch,
770 			      PRBool             use_epoch,
771                               SSL3ContentType    type,
772 		              const SSL3Opaque * pIn,
773 		              PRUint32           contentLen,
774 			      sslBuffer        * wrBuf)
775 {
776     SECStatus rv = SECFailure;
777     ssl3CipherSpec *          cwSpec;
778 
779     ssl_GetSpecReadLock(ss);	/********************************/
780 
781     /* The reason for this switch-hitting code is that we might have
782      * a flight of records spanning an epoch boundary, e.g.,
783      *
784      * ClientKeyExchange (epoch = 0)
785      * ChangeCipherSpec (epoch = 0)
786      * Finished (epoch = 1)
787      *
788      * Thus, each record needs a different cipher spec. The information
789      * about which epoch to use is carried with the record.
790      */
791     if (use_epoch) {
792 	if (ss->ssl3.cwSpec->epoch == epoch)
793 	    cwSpec = ss->ssl3.cwSpec;
794 	else if (ss->ssl3.pwSpec->epoch == epoch)
795 	    cwSpec = ss->ssl3.pwSpec;
796 	else
797 	    cwSpec = NULL;
798     } else {
799 	cwSpec = ss->ssl3.cwSpec;
800     }
801 
802     if (cwSpec) {
803         rv = ssl3_CompressMACEncryptRecord(cwSpec, ss->sec.isServer, PR_TRUE,
804 					   PR_FALSE, type, pIn, contentLen,
805 					   wrBuf);
806     } else {
807         PR_NOT_REACHED("Couldn't find a cipher spec matching epoch");
808 	PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
809     }
810     ssl_ReleaseSpecReadLock(ss); /************************************/
811 
812     return rv;
813 }
814 
815 /* Start a timer
816  *
817  * Called from:
818  *             dtls_HandleHandshake()
819  *             dtls_FlushHAndshake()
820  *             dtls_RestartTimer()
821  */
822 SECStatus
dtls_StartTimer(sslSocket * ss,DTLSTimerCb cb)823 dtls_StartTimer(sslSocket *ss, DTLSTimerCb cb)
824 {
825     PORT_Assert(ss->ssl3.hs.rtTimerCb == NULL);
826 
827     ss->ssl3.hs.rtTimerStarted = PR_IntervalNow();
828     ss->ssl3.hs.rtTimerCb = cb;
829 
830     return SECSuccess;
831 }
832 
833 /* Restart a timer with optional backoff
834  *
835  * Called from dtls_RetransmitTimerExpiredCb()
836  */
837 SECStatus
dtls_RestartTimer(sslSocket * ss,PRBool backoff,DTLSTimerCb cb)838 dtls_RestartTimer(sslSocket *ss, PRBool backoff, DTLSTimerCb cb)
839 {
840     if (backoff) {
841 	ss->ssl3.hs.rtTimeoutMs *= 2;
842 	if (ss->ssl3.hs.rtTimeoutMs > MAX_DTLS_TIMEOUT_MS)
843 	    ss->ssl3.hs.rtTimeoutMs = MAX_DTLS_TIMEOUT_MS;
844     }
845 
846     return dtls_StartTimer(ss, cb);
847 }
848 
849 /* Cancel a pending timer
850  *
851  * Called from:
852  *              dtls_HandleHandshake()
853  *              dtls_CheckTimer()
854  */
855 void
dtls_CancelTimer(sslSocket * ss)856 dtls_CancelTimer(sslSocket *ss)
857 {
858     PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
859 
860     ss->ssl3.hs.rtTimerCb = NULL;
861 }
862 
863 /* Check the pending timer and fire the callback if it expired
864  *
865  * Called from ssl3_GatherCompleteHandshake()
866  */
867 void
dtls_CheckTimer(sslSocket * ss)868 dtls_CheckTimer(sslSocket *ss)
869 {
870     if (!ss->ssl3.hs.rtTimerCb)
871 	return;
872 
873     if ((PR_IntervalNow() - ss->ssl3.hs.rtTimerStarted) >
874 	PR_MillisecondsToInterval(ss->ssl3.hs.rtTimeoutMs)) {
875 	/* Timer has expired */
876 	DTLSTimerCb cb = ss->ssl3.hs.rtTimerCb;
877 
878 	/* Cancel the timer so that we can call the CB safely */
879 	dtls_CancelTimer(ss);
880 
881 	/* Now call the CB */
882 	cb(ss);
883     }
884 }
885 
886 /* The callback to fire when the holddown timer for the Finished
887  * message expires and we can delete it
888  *
889  * Called from dtls_CheckTimer()
890  */
891 void
dtls_FinishedTimerCb(sslSocket * ss)892 dtls_FinishedTimerCb(sslSocket *ss)
893 {
894     ssl3_DestroyCipherSpec(ss->ssl3.pwSpec, PR_FALSE);
895 }
896 
897 /* Cancel the Finished hold-down timer and destroy the
898  * pending cipher spec. Note that this means that
899  * successive rehandshakes will fail if the Finished is
900  * lost.
901  *
902  * XXX OK for now. Figure out how to handle the combination
903  * of Finished lost and rehandshake
904  */
905 void
dtls_RehandshakeCleanup(sslSocket * ss)906 dtls_RehandshakeCleanup(sslSocket *ss)
907 {
908     dtls_CancelTimer(ss);
909     ssl3_DestroyCipherSpec(ss->ssl3.pwSpec, PR_FALSE);
910     ss->ssl3.hs.sendMessageSeq = 0;
911     ss->ssl3.hs.recvMessageSeq = 0;
912 }
913 
914 /* Set the MTU to the next step less than or equal to the
915  * advertised value. Also used to downgrade the MTU by
916  * doing dtls_SetMTU(ss, biggest packet set).
917  *
918  * Passing 0 means set this to the largest MTU known
919  * (effectively resetting the PMTU backoff value).
920  *
921  * Called by:
922  *            ssl3_InitState()
923  *            dtls_RetransmitTimerExpiredCb()
924  */
925 void
dtls_SetMTU(sslSocket * ss,PRUint16 advertised)926 dtls_SetMTU(sslSocket *ss, PRUint16 advertised)
927 {
928     int i;
929 
930     if (advertised == 0) {
931 	ss->ssl3.mtu = COMMON_MTU_VALUES[0];
932 	SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu));
933 	return;
934     }
935 
936     for (i = 0; i < PR_ARRAY_SIZE(COMMON_MTU_VALUES); i++) {
937 	if (COMMON_MTU_VALUES[i] <= advertised) {
938 	    ss->ssl3.mtu = COMMON_MTU_VALUES[i];
939 	    SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu));
940 	    return;
941 	}
942     }
943 
944     /* Fallback */
945     ss->ssl3.mtu = COMMON_MTU_VALUES[PR_ARRAY_SIZE(COMMON_MTU_VALUES)-1];
946     SSL_TRC(30, ("Resetting MTU to %d", ss->ssl3.mtu));
947 }
948 
949 /* Called from ssl3_HandleHandshakeMessage() when it has deciphered a
950  * DTLS hello_verify_request
951  * Caller must hold Handshake and RecvBuf locks.
952  */
953 SECStatus
dtls_HandleHelloVerifyRequest(sslSocket * ss,SSL3Opaque * b,PRUint32 length)954 dtls_HandleHelloVerifyRequest(sslSocket *ss, SSL3Opaque *b, PRUint32 length)
955 {
956     int                 errCode	= SSL_ERROR_RX_MALFORMED_HELLO_VERIFY_REQUEST;
957     SECStatus           rv;
958     PRInt32             temp;
959     SECItem             cookie = {siBuffer, NULL, 0};
960     SSL3AlertDescription desc   = illegal_parameter;
961 
962     SSL_TRC(3, ("%d: SSL3[%d]: handle hello_verify_request handshake",
963     	SSL_GETPID(), ss->fd));
964     PORT_Assert(ss->opt.noLocks || ssl_HaveRecvBufLock(ss));
965     PORT_Assert(ss->opt.noLocks || ssl_HaveSSL3HandshakeLock(ss));
966 
967     if (ss->ssl3.hs.ws != wait_server_hello) {
968         errCode = SSL_ERROR_RX_UNEXPECTED_HELLO_VERIFY_REQUEST;
969 	desc    = unexpected_message;
970 	goto alert_loser;
971     }
972 
973     /* The version */
974     temp = ssl3_ConsumeHandshakeNumber(ss, 2, &b, &length);
975     if (temp < 0) {
976     	goto loser; 	/* alert has been sent */
977     }
978 
979     if (temp != SSL_LIBRARY_VERSION_DTLS_1_0_WIRE) {
980 	/* Note: this will need adjustment for DTLS 1.2 per Section 4.2.1 */
981 	goto alert_loser;
982     }
983 
984     /* The cookie */
985     rv = ssl3_ConsumeHandshakeVariable(ss, &cookie, 1, &b, &length);
986     if (rv != SECSuccess) {
987     	goto loser; 	/* alert has been sent */
988     }
989     if (cookie.len > DTLS_COOKIE_BYTES) {
990 	desc = decode_error;
991 	goto alert_loser;	/* malformed. */
992     }
993 
994     PORT_Memcpy(ss->ssl3.hs.cookie, cookie.data, cookie.len);
995     ss->ssl3.hs.cookieLen = cookie.len;
996 
997 
998     ssl_GetXmitBufLock(ss);		/*******************************/
999 
1000     /* Now re-send the client hello */
1001     rv = ssl3_SendClientHello(ss, PR_TRUE);
1002 
1003     ssl_ReleaseXmitBufLock(ss);		/*******************************/
1004 
1005     if (rv == SECSuccess)
1006 	return rv;
1007 
1008 alert_loser:
1009     (void)SSL3_SendAlert(ss, alert_fatal, desc);
1010 
1011 loser:
1012     errCode = ssl_MapLowLevelError(errCode);
1013     return SECFailure;
1014 }
1015 
1016 /* Initialize the DTLS anti-replay window
1017  *
1018  * Called from:
1019  *              ssl3_SetupPendingCipherSpec()
1020  *              ssl3_InitCipherSpec()
1021  */
1022 void
dtls_InitRecvdRecords(DTLSRecvdRecords * records)1023 dtls_InitRecvdRecords(DTLSRecvdRecords *records)
1024 {
1025     PORT_Memset(records->data, 0, sizeof(records->data));
1026     records->left = 0;
1027     records->right = DTLS_RECVD_RECORDS_WINDOW - 1;
1028 }
1029 
1030 /*
1031  * Has this DTLS record been received? Return values are:
1032  * -1 -- out of range to the left
1033  *  0 -- not received yet
1034  *  1 -- replay
1035  *
1036  *  Called from: dtls_HandleRecord()
1037  */
1038 int
dtls_RecordGetRecvd(DTLSRecvdRecords * records,PRUint64 seq)1039 dtls_RecordGetRecvd(DTLSRecvdRecords *records, PRUint64 seq)
1040 {
1041     PRUint64 offset;
1042 
1043     /* Out of range to the left */
1044     if (seq < records->left) {
1045 	return -1;
1046     }
1047 
1048     /* Out of range to the right; since we advance the window on
1049      * receipt, that means that this packet has not been received
1050      * yet */
1051     if (seq > records->right)
1052 	return 0;
1053 
1054     offset = seq % DTLS_RECVD_RECORDS_WINDOW;
1055 
1056     return !!(records->data[offset / 8] & (1 << (offset % 8)));
1057 }
1058 
1059 /* Update the DTLS anti-replay window
1060  *
1061  * Called from ssl3_HandleRecord()
1062  */
1063 void
dtls_RecordSetRecvd(DTLSRecvdRecords * records,PRUint64 seq)1064 dtls_RecordSetRecvd(DTLSRecvdRecords *records, PRUint64 seq)
1065 {
1066     PRUint64 offset;
1067 
1068     if (seq < records->left)
1069 	return;
1070 
1071     if (seq > records->right) {
1072 	PRUint64 new_left;
1073 	PRUint64 new_right;
1074 	PRUint64 right;
1075 
1076 	/* Slide to the right; this is the tricky part
1077          *
1078 	 * 1. new_top is set to have room for seq, on the
1079 	 *    next byte boundary by setting the right 8
1080 	 *    bits of seq
1081          * 2. new_left is set to compensate.
1082          * 3. Zero all bits between top and new_top. Since
1083          *    this is a ring, this zeroes everything as-yet
1084 	 *    unseen. Because we always operate on byte
1085 	 *    boundaries, we can zero one byte at a time
1086 	 */
1087 	new_right = seq | 0x07;
1088 	new_left = (new_right - DTLS_RECVD_RECORDS_WINDOW) + 1;
1089 
1090 	for (right = records->right + 8; right <= new_right; right += 8) {
1091 	    offset = right % DTLS_RECVD_RECORDS_WINDOW;
1092 	    records->data[offset / 8] = 0;
1093 	}
1094 
1095 	records->right = new_right;
1096 	records->left = new_left;
1097     }
1098 
1099     offset = seq % DTLS_RECVD_RECORDS_WINDOW;
1100 
1101     records->data[offset / 8] |= (1 << (offset % 8));
1102 }
1103 
1104 SECStatus
DTLS_GetHandshakeTimeout(PRFileDesc * socket,PRIntervalTime * timeout)1105 DTLS_GetHandshakeTimeout(PRFileDesc *socket, PRIntervalTime *timeout)
1106 {
1107     sslSocket * ss = NULL;
1108     PRIntervalTime elapsed;
1109     PRIntervalTime desired;
1110 
1111     ss = ssl_FindSocket(socket);
1112 
1113     if (!ss)
1114         return SECFailure;
1115 
1116     if (!IS_DTLS(ss))
1117         return SECFailure;
1118 
1119     if (!ss->ssl3.hs.rtTimerCb)
1120         return SECFailure;
1121 
1122     elapsed = PR_IntervalNow() - ss->ssl3.hs.rtTimerStarted;
1123     desired = PR_MillisecondsToInterval(ss->ssl3.hs.rtTimeoutMs);
1124     if (elapsed > desired) {
1125         /* Timer expired */
1126         *timeout = PR_INTERVAL_NO_WAIT;
1127     } else {
1128         *timeout = desired - elapsed;
1129     }
1130 
1131     return SECSuccess;
1132 }
1133