• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// DTLS implementation.
6//
// NOTE: This is not even a remotely production-quality DTLS
// implementation. It is the bare minimum necessary to be able to
// achieve coverage on BoringSSL's implementation. Of note is that
// this implementation assumes the underlying net.PacketConn is not
// only reliable but also ordered. BoringSSL will be expected to deal
// with simulated loss, but there is no point in forcing the test
// driver to.
14
15package runner
16
17import (
18	"bytes"
19	"errors"
20	"fmt"
21	"io"
22	"math/rand"
23	"net"
24)
25
// versionToWire converts a protocol version into the value written on the
// wire. TLS versions are emitted unchanged. DTLS versions are encoded as the
// one's complement of (vers - 0x0201), so for example 0x0303 becomes 0xfefd
// and 0x0301 becomes 0xfeff.
func versionToWire(vers uint16, isDTLS bool) uint16 {
	if !isDTLS {
		return vers
	}
	const dtlsVersionBias = 0x0201
	return ^(vers - dtlsVersionBias)
}
32
// wireToVersion is the inverse of versionToWire. It decodes a version read
// off the wire. TLS versions pass through unchanged. DTLS versions are
// complemented and then offset by 0x0201, so for example 0xfefd decodes to
// 0x0303.
func wireToVersion(vers uint16, isDTLS bool) uint16 {
	if !isDTLS {
		return vers
	}
	const dtlsVersionBias = 0x0201
	return ^vers + dtlsVersionBias
}
39
40func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) {
41	recordHeaderLen := dtlsRecordHeaderLen
42
43	if c.rawInput == nil {
44		c.rawInput = c.in.newBlock()
45	}
46	b := c.rawInput
47
48	// Read a new packet only if the current one is empty.
49	if len(b.data) == 0 {
50		// Pick some absurdly large buffer size.
51		b.resize(maxCiphertext + recordHeaderLen)
52		n, err := c.conn.Read(c.rawInput.data)
53		if err != nil {
54			return 0, nil, err
55		}
56		if c.config.Bugs.MaxPacketLength != 0 && n > c.config.Bugs.MaxPacketLength {
57			return 0, nil, fmt.Errorf("dtls: exceeded maximum packet length")
58		}
59		c.rawInput.resize(n)
60	}
61
62	// Read out one record.
63	//
64	// A real DTLS implementation should be tolerant of errors,
65	// but this is test code. We should not be tolerant of our
66	// peer sending garbage.
67	if len(b.data) < recordHeaderLen {
68		return 0, nil, errors.New("dtls: failed to read record header")
69	}
70	typ := recordType(b.data[0])
71	vers := wireToVersion(uint16(b.data[1])<<8|uint16(b.data[2]), c.isDTLS)
72	if c.haveVers {
73		if vers != c.vers {
74			c.sendAlert(alertProtocolVersion)
75			return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.vers))
76		}
77	} else {
78		if expect := c.config.Bugs.ExpectInitialRecordVersion; expect != 0 && vers != expect {
79			c.sendAlert(alertProtocolVersion)
80			return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect))
81		}
82	}
83	epoch := b.data[3:5]
84	seq := b.data[5:11]
85	// For test purposes, require the sequence number be monotonically
86	// increasing, so c.in includes the minimum next sequence number. Gaps
87	// may occur if packets failed to be sent out. A real implementation
88	// would maintain a replay window and such.
89	if !bytes.Equal(epoch, c.in.seq[:2]) {
90		c.sendAlert(alertIllegalParameter)
91		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch"))
92	}
93	if bytes.Compare(seq, c.in.seq[2:]) < 0 {
94		c.sendAlert(alertIllegalParameter)
95		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number"))
96	}
97	copy(c.in.seq[2:], seq)
98	n := int(b.data[11])<<8 | int(b.data[12])
99	if n > maxCiphertext || len(b.data) < recordHeaderLen+n {
100		c.sendAlert(alertRecordOverflow)
101		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: oversized record received with length %d", n))
102	}
103
104	// Process message.
105	b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
106	ok, off, err := c.in.decrypt(b)
107	if !ok {
108		c.in.setErrorLocked(c.sendAlert(err))
109	}
110	b.off = off
111	return typ, b, nil
112}
113
114func (c *Conn) makeFragment(header, data []byte, fragOffset, fragLen int) []byte {
115	fragment := make([]byte, 0, 12+fragLen)
116	fragment = append(fragment, header...)
117	fragment = append(fragment, byte(c.sendHandshakeSeq>>8), byte(c.sendHandshakeSeq))
118	fragment = append(fragment, byte(fragOffset>>16), byte(fragOffset>>8), byte(fragOffset))
119	fragment = append(fragment, byte(fragLen>>16), byte(fragLen>>8), byte(fragLen))
120	fragment = append(fragment, data[fragOffset:fragOffset+fragLen]...)
121	return fragment
122}
123
124func (c *Conn) dtlsWriteRecord(typ recordType, data []byte) (n int, err error) {
125	if typ != recordTypeHandshake {
126		// Only handshake messages are fragmented.
127		return c.dtlsWriteRawRecord(typ, data)
128	}
129
130	maxLen := c.config.Bugs.MaxHandshakeRecordLength
131	if maxLen <= 0 {
132		maxLen = 1024
133	}
134
135	// Handshake messages have to be modified to include fragment
136	// offset and length and with the header replicated. Save the
137	// TLS header here.
138	//
139	// TODO(davidben): This assumes that data contains exactly one
140	// handshake message. This is incompatible with
141	// FragmentAcrossChangeCipherSpec. (Which is unfortunate
142	// because OpenSSL's DTLS implementation will probably accept
143	// such fragmentation and could do with a fix + tests.)
144	header := data[:4]
145	data = data[4:]
146
147	isFinished := header[0] == typeFinished
148
149	if c.config.Bugs.SendEmptyFragments {
150		fragment := c.makeFragment(header, data, 0, 0)
151		c.pendingFragments = append(c.pendingFragments, fragment)
152	}
153
154	firstRun := true
155	fragOffset := 0
156	for firstRun || fragOffset < len(data) {
157		firstRun = false
158		fragLen := len(data) - fragOffset
159		if fragLen > maxLen {
160			fragLen = maxLen
161		}
162
163		fragment := c.makeFragment(header, data, fragOffset, fragLen)
164		if c.config.Bugs.FragmentMessageTypeMismatch && fragOffset > 0 {
165			fragment[0]++
166		}
167		if c.config.Bugs.FragmentMessageLengthMismatch && fragOffset > 0 {
168			fragment[3]++
169		}
170
171		// Buffer the fragment for later. They will be sent (and
172		// reordered) on flush.
173		c.pendingFragments = append(c.pendingFragments, fragment)
174		if c.config.Bugs.ReorderHandshakeFragments {
175			// Don't duplicate Finished to avoid the peer
176			// interpreting it as a retransmit request.
177			if !isFinished {
178				c.pendingFragments = append(c.pendingFragments, fragment)
179			}
180
181			if fragLen > (maxLen+1)/2 {
182				// Overlap each fragment by half.
183				fragLen = (maxLen + 1) / 2
184			}
185		}
186		fragOffset += fragLen
187		n += fragLen
188	}
189	if !isFinished && c.config.Bugs.MixCompleteMessageWithFragments {
190		fragment := c.makeFragment(header, data, 0, len(data))
191		c.pendingFragments = append(c.pendingFragments, fragment)
192	}
193
194	// Increment the handshake sequence number for the next
195	// handshake message.
196	c.sendHandshakeSeq++
197	return
198}
199
200func (c *Conn) dtlsFlushHandshake() error {
201	if !c.isDTLS {
202		return nil
203	}
204
205	// This is a test-only DTLS implementation, so there is no need to
206	// retain |c.pendingFragments| for a future retransmit.
207	var fragments [][]byte
208	fragments, c.pendingFragments = c.pendingFragments, fragments
209
210	if c.config.Bugs.ReorderHandshakeFragments {
211		perm := rand.New(rand.NewSource(0)).Perm(len(fragments))
212		tmp := make([][]byte, len(fragments))
213		for i := range tmp {
214			tmp[i] = fragments[perm[i]]
215		}
216		fragments = tmp
217	}
218
219	maxRecordLen := c.config.Bugs.PackHandshakeFragments
220	maxPacketLen := c.config.Bugs.PackHandshakeRecords
221
222	// Pack handshake fragments into records.
223	var records [][]byte
224	for _, fragment := range fragments {
225		if n := c.config.Bugs.SplitFragments; n > 0 {
226			if len(fragment) > n {
227				records = append(records, fragment[:n])
228				records = append(records, fragment[n:])
229			} else {
230				records = append(records, fragment)
231			}
232		} else if i := len(records) - 1; len(records) > 0 && len(records[i])+len(fragment) <= maxRecordLen {
233			records[i] = append(records[i], fragment...)
234		} else {
235			// The fragment will be appended to, so copy it.
236			records = append(records, append([]byte{}, fragment...))
237		}
238	}
239
240	// Format them into packets.
241	var packets [][]byte
242	for _, record := range records {
243		b, err := c.dtlsSealRecord(recordTypeHandshake, record)
244		if err != nil {
245			return err
246		}
247
248		if i := len(packets) - 1; len(packets) > 0 && len(packets[i])+len(b.data) <= maxPacketLen {
249			packets[i] = append(packets[i], b.data...)
250		} else {
251			// The sealed record will be appended to and reused by
252			// |c.out|, so copy it.
253			packets = append(packets, append([]byte{}, b.data...))
254		}
255		c.out.freeBlock(b)
256	}
257
258	// Send all the packets.
259	for _, packet := range packets {
260		if _, err := c.conn.Write(packet); err != nil {
261			return err
262		}
263	}
264	return nil
265}
266
267// dtlsSealRecord seals a record into a block from |c.out|'s pool.
268func (c *Conn) dtlsSealRecord(typ recordType, data []byte) (b *block, err error) {
269	recordHeaderLen := dtlsRecordHeaderLen
270	maxLen := c.config.Bugs.MaxHandshakeRecordLength
271	if maxLen <= 0 {
272		maxLen = 1024
273	}
274
275	b = c.out.newBlock()
276
277	explicitIVLen := 0
278	explicitIVIsSeq := false
279
280	if cbc, ok := c.out.cipher.(cbcMode); ok {
281		// Block cipher modes have an explicit IV.
282		explicitIVLen = cbc.BlockSize()
283	} else if aead, ok := c.out.cipher.(*tlsAead); ok {
284		if aead.explicitNonce {
285			explicitIVLen = 8
286			// The AES-GCM construction in TLS has an explicit nonce so that
287			// the nonce can be random. However, the nonce is only 8 bytes
288			// which is too small for a secure, random nonce. Therefore we
289			// use the sequence number as the nonce.
290			explicitIVIsSeq = true
291		}
292	} else if c.out.cipher != nil {
293		panic("Unknown cipher")
294	}
295	b.resize(recordHeaderLen + explicitIVLen + len(data))
296	b.data[0] = byte(typ)
297	vers := c.vers
298	if vers == 0 {
299		// Some TLS servers fail if the record version is greater than
300		// TLS 1.0 for the initial ClientHello.
301		vers = VersionTLS10
302	}
303	vers = versionToWire(vers, c.isDTLS)
304	b.data[1] = byte(vers >> 8)
305	b.data[2] = byte(vers)
306	// DTLS records include an explicit sequence number.
307	copy(b.data[3:11], c.out.outSeq[0:])
308	b.data[11] = byte(len(data) >> 8)
309	b.data[12] = byte(len(data))
310	if explicitIVLen > 0 {
311		explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
312		if explicitIVIsSeq {
313			copy(explicitIV, c.out.outSeq[:])
314		} else {
315			if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil {
316				return
317			}
318		}
319	}
320	copy(b.data[recordHeaderLen+explicitIVLen:], data)
321	c.out.encrypt(b, explicitIVLen)
322	return
323}
324
325func (c *Conn) dtlsWriteRawRecord(typ recordType, data []byte) (n int, err error) {
326	b, err := c.dtlsSealRecord(typ, data)
327	if err != nil {
328		return
329	}
330
331	_, err = c.conn.Write(b.data)
332	if err != nil {
333		return
334	}
335	n = len(data)
336
337	c.out.freeBlock(b)
338
339	if typ == recordTypeChangeCipherSpec {
340		err = c.out.changeCipherSpec(c.config)
341		if err != nil {
342			// Cannot call sendAlert directly,
343			// because we already hold c.out.Mutex.
344			c.tmp[0] = alertLevelError
345			c.tmp[1] = byte(err.(alert))
346			c.writeRecord(recordTypeAlert, c.tmp[0:2])
347			return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
348		}
349	}
350	return
351}
352
353func (c *Conn) dtlsDoReadHandshake() ([]byte, error) {
354	// Assemble a full handshake message.  For test purposes, this
355	// implementation assumes fragments arrive in order. It may
356	// need to be cleverer if we ever test BoringSSL's retransmit
357	// behavior.
358	for len(c.handMsg) < 4+c.handMsgLen {
359		// Get a new handshake record if the previous has been
360		// exhausted.
361		if c.hand.Len() == 0 {
362			if err := c.in.err; err != nil {
363				return nil, err
364			}
365			if err := c.readRecord(recordTypeHandshake); err != nil {
366				return nil, err
367			}
368		}
369
370		// Read the next fragment. It must fit entirely within
371		// the record.
372		if c.hand.Len() < 12 {
373			return nil, errors.New("dtls: bad handshake record")
374		}
375		header := c.hand.Next(12)
376		fragN := int(header[1])<<16 | int(header[2])<<8 | int(header[3])
377		fragSeq := uint16(header[4])<<8 | uint16(header[5])
378		fragOff := int(header[6])<<16 | int(header[7])<<8 | int(header[8])
379		fragLen := int(header[9])<<16 | int(header[10])<<8 | int(header[11])
380
381		if c.hand.Len() < fragLen {
382			return nil, errors.New("dtls: fragment length too long")
383		}
384		fragment := c.hand.Next(fragLen)
385
386		// Check it's a fragment for the right message.
387		if fragSeq != c.recvHandshakeSeq {
388			return nil, errors.New("dtls: bad handshake sequence number")
389		}
390
391		// Check that the length is consistent.
392		if c.handMsg == nil {
393			c.handMsgLen = fragN
394			if c.handMsgLen > maxHandshake {
395				return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError))
396			}
397			// Start with the TLS handshake header,
398			// without the DTLS bits.
399			c.handMsg = append([]byte{}, header[:4]...)
400		} else if fragN != c.handMsgLen {
401			return nil, errors.New("dtls: bad handshake length")
402		}
403
404		// Add the fragment to the pending message.
405		if 4+fragOff != len(c.handMsg) {
406			return nil, errors.New("dtls: bad fragment offset")
407		}
408		if fragOff+fragLen > c.handMsgLen {
409			return nil, errors.New("dtls: bad fragment length")
410		}
411		c.handMsg = append(c.handMsg, fragment...)
412	}
413	c.recvHandshakeSeq++
414	ret := c.handMsg
415	c.handMsg, c.handMsgLen = nil, 0
416	return ret, nil
417}
418
419// DTLSServer returns a new DTLS server side connection
420// using conn as the underlying transport.
421// The configuration config must be non-nil and must have
422// at least one certificate.
423func DTLSServer(conn net.Conn, config *Config) *Conn {
424	c := &Conn{config: config, isDTLS: true, conn: conn}
425	c.init()
426	return c
427}
428
429// DTLSClient returns a new DTLS client side connection
430// using conn as the underlying transport.
431// The config cannot be nil: users must set either ServerHostname or
432// InsecureSkipVerify in the config.
433func DTLSClient(conn net.Conn, config *Config) *Conn {
434	c := &Conn{config: config, isClient: true, isDTLS: true, conn: conn}
435	c.init()
436	return c
437}
438