• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
// DTLS implementation.
//
// NOTE: This is not even a remotely production-quality DTLS
// implementation. It is the bare minimum necessary to be able to
// achieve coverage on BoringSSL's implementation. Of note is that
// this implementation assumes the underlying net.PacketConn is not
// only reliable but also ordered. BoringSSL will be expected to deal
// with simulated loss, but there is no point in forcing the test
// driver to.
14
15package main
16
17import (
18	"bytes"
19	"crypto/cipher"
20	"errors"
21	"fmt"
22	"io"
23	"net"
24)
25
// versionToWire converts an internal protocol version number into the value
// that is written on the wire. When isDTLS is false the version is returned
// unchanged; otherwise the result is the bitwise complement of
// (vers - 0x0201), computed in wrapping uint16 arithmetic.
func versionToWire(vers uint16, isDTLS bool) uint16 {
	if !isDTLS {
		return vers
	}
	const dtlsVersionBias = 0x0201
	return ^(vers - dtlsVersionBias)
}
32
// wireToVersion converts a version value read off the wire into the internal
// protocol version number. When isDTLS is false the value is returned
// unchanged; otherwise the result is the bitwise complement of vers plus
// 0x0201, computed in wrapping uint16 arithmetic.
func wireToVersion(vers uint16, isDTLS bool) uint16 {
	if !isDTLS {
		return vers
	}
	const dtlsVersionBias = 0x0201
	inverted := ^vers
	return inverted + dtlsVersionBias
}
39
40func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) {
41	recordHeaderLen := dtlsRecordHeaderLen
42
43	if c.rawInput == nil {
44		c.rawInput = c.in.newBlock()
45	}
46	b := c.rawInput
47
48	// Read a new packet only if the current one is empty.
49	if len(b.data) == 0 {
50		// Pick some absurdly large buffer size.
51		b.resize(maxCiphertext + recordHeaderLen)
52		n, err := c.conn.Read(c.rawInput.data)
53		if err != nil {
54			return 0, nil, err
55		}
56		c.rawInput.resize(n)
57	}
58
59	// Read out one record.
60	//
61	// A real DTLS implementation should be tolerant of errors,
62	// but this is test code. We should not be tolerant of our
63	// peer sending garbage.
64	if len(b.data) < recordHeaderLen {
65		return 0, nil, errors.New("dtls: failed to read record header")
66	}
67	typ := recordType(b.data[0])
68	vers := wireToVersion(uint16(b.data[1])<<8|uint16(b.data[2]), c.isDTLS)
69	if c.haveVers && vers != c.vers {
70		c.sendAlert(alertProtocolVersion)
71		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.vers))
72	}
73	seq := b.data[3:11]
74	// For test purposes, we assume a reliable channel. Require
75	// that the explicit sequence number matches the incrementing
76	// one we maintain. A real implementation would maintain a
77	// replay window and such.
78	if !bytes.Equal(seq, c.in.seq[:]) {
79		c.sendAlert(alertIllegalParameter)
80		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number"))
81	}
82	n := int(b.data[11])<<8 | int(b.data[12])
83	if n > maxCiphertext || len(b.data) < recordHeaderLen+n {
84		c.sendAlert(alertRecordOverflow)
85		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: oversized record received with length %d", n))
86	}
87
88	// Process message.
89	b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
90	ok, off, err := c.in.decrypt(b)
91	if !ok {
92		c.in.setErrorLocked(c.sendAlert(err))
93	}
94	b.off = off
95	return typ, b, nil
96}
97
98func (c *Conn) dtlsWriteRecord(typ recordType, data []byte) (n int, err error) {
99	recordHeaderLen := dtlsRecordHeaderLen
100	maxLen := c.config.Bugs.MaxHandshakeRecordLength
101	if maxLen <= 0 {
102		maxLen = 1024
103	}
104
105	b := c.out.newBlock()
106
107	var header []byte
108	if typ == recordTypeHandshake {
109		// Handshake messages have to be modified to include
110		// fragment offset and length and with the header
111		// replicated. Save the header here.
112		//
113		// TODO(davidben): This assumes that data contains
114		// exactly one handshake message. This is incompatible
115		// with FragmentAcrossChangeCipherSpec. (Which is
116		// unfortunate because OpenSSL's DTLS implementation
117		// will probably accept such fragmentation and could
118		// do with a fix + tests.)
119		if len(data) < 4 {
120			// This should not happen.
121			panic(data)
122		}
123		header = data[:4]
124		data = data[4:]
125	}
126
127	firstRun := true
128	for firstRun || len(data) > 0 {
129		firstRun = false
130		m := len(data)
131		var fragment []byte
132		// Handshake messages get fragmented. Other records we
133		// pass-through as is. DTLS should be a packet
134		// interface.
135		if typ == recordTypeHandshake {
136			if m > maxLen {
137				m = maxLen
138			}
139
140			// Standard handshake header.
141			fragment = make([]byte, 0, 12+m)
142			fragment = append(fragment, header...)
143			// message_seq
144			fragment = append(fragment, byte(c.sendHandshakeSeq>>8), byte(c.sendHandshakeSeq))
145			// fragment_offset
146			fragment = append(fragment, byte(n>>16), byte(n>>8), byte(n))
147			// fragment_length
148			fragment = append(fragment, byte(m>>16), byte(m>>8), byte(m))
149			fragment = append(fragment, data[:m]...)
150		} else {
151			fragment = data[:m]
152		}
153
154		// Send the fragment.
155		explicitIVLen := 0
156		explicitIVIsSeq := false
157
158		if cbc, ok := c.out.cipher.(cbcMode); ok {
159			// Block cipher modes have an explicit IV.
160			explicitIVLen = cbc.BlockSize()
161		} else if _, ok := c.out.cipher.(cipher.AEAD); ok {
162			explicitIVLen = 8
163			// The AES-GCM construction in TLS has an
164			// explicit nonce so that the nonce can be
165			// random. However, the nonce is only 8 bytes
166			// which is too small for a secure, random
167			// nonce. Therefore we use the sequence number
168			// as the nonce.
169			explicitIVIsSeq = true
170		} else if c.out.cipher != nil {
171			panic("Unknown cipher")
172		}
173		b.resize(recordHeaderLen + explicitIVLen + len(fragment))
174		b.data[0] = byte(typ)
175		vers := c.vers
176		if vers == 0 {
177			// Some TLS servers fail if the record version is
178			// greater than TLS 1.0 for the initial ClientHello.
179			vers = VersionTLS10
180		}
181		vers = versionToWire(vers, c.isDTLS)
182		b.data[1] = byte(vers >> 8)
183		b.data[2] = byte(vers)
184		// DTLS records include an explicit sequence number.
185		copy(b.data[3:11], c.out.seq[0:])
186		b.data[11] = byte(len(fragment) >> 8)
187		b.data[12] = byte(len(fragment))
188		if explicitIVLen > 0 {
189			explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
190			if explicitIVIsSeq {
191				copy(explicitIV, c.out.seq[:])
192			} else {
193				if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil {
194					break
195				}
196			}
197		}
198		copy(b.data[recordHeaderLen+explicitIVLen:], fragment)
199		c.out.encrypt(b, explicitIVLen)
200
201		// TODO(davidben): A real DTLS implementation needs to
202		// retransmit handshake messages. For testing
203		// purposes, we don't actually care.
204		_, err = c.conn.Write(b.data)
205		if err != nil {
206			break
207		}
208		n += m
209		data = data[m:]
210	}
211	c.out.freeBlock(b)
212
213	// Increment the handshake sequence number for the next
214	// handshake message.
215	if typ == recordTypeHandshake {
216		c.sendHandshakeSeq++
217	}
218
219	if typ == recordTypeChangeCipherSpec {
220		err = c.out.changeCipherSpec(c.config)
221		if err != nil {
222			// Cannot call sendAlert directly,
223			// because we already hold c.out.Mutex.
224			c.tmp[0] = alertLevelError
225			c.tmp[1] = byte(err.(alert))
226			c.writeRecord(recordTypeAlert, c.tmp[0:2])
227			return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
228		}
229	}
230	return
231}
232
233func (c *Conn) dtlsDoReadHandshake() ([]byte, error) {
234	// Assemble a full handshake message.  For test purposes, this
235	// implementation assumes fragments arrive in order. It may
236	// need to be cleverer if we ever test BoringSSL's retransmit
237	// behavior.
238	for len(c.handMsg) < 4+c.handMsgLen {
239		// Get a new handshake record if the previous has been
240		// exhausted.
241		if c.hand.Len() == 0 {
242			if err := c.in.err; err != nil {
243				return nil, err
244			}
245			if err := c.readRecord(recordTypeHandshake); err != nil {
246				return nil, err
247			}
248		}
249
250		// Read the next fragment. It must fit entirely within
251		// the record.
252		if c.hand.Len() < 12 {
253			return nil, errors.New("dtls: bad handshake record")
254		}
255		header := c.hand.Next(12)
256		fragN := int(header[1])<<16 | int(header[2])<<8 | int(header[3])
257		fragSeq := uint16(header[4])<<8 | uint16(header[5])
258		fragOff := int(header[6])<<16 | int(header[7])<<8 | int(header[8])
259		fragLen := int(header[9])<<16 | int(header[10])<<8 | int(header[11])
260
261		if c.hand.Len() < fragLen {
262			return nil, errors.New("dtls: fragment length too long")
263		}
264		fragment := c.hand.Next(fragLen)
265
266		// Check it's a fragment for the right message.
267		if fragSeq != c.recvHandshakeSeq {
268			return nil, errors.New("dtls: bad handshake sequence number")
269		}
270
271		// Check that the length is consistent.
272		if c.handMsg == nil {
273			c.handMsgLen = fragN
274			if c.handMsgLen > maxHandshake {
275				return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError))
276			}
277			// Start with the TLS handshake header,
278			// without the DTLS bits.
279			c.handMsg = append([]byte{}, header[:4]...)
280		} else if fragN != c.handMsgLen {
281			return nil, errors.New("dtls: bad handshake length")
282		}
283
284		// Add the fragment to the pending message.
285		if 4+fragOff != len(c.handMsg) {
286			return nil, errors.New("dtls: bad fragment offset")
287		}
288		if fragOff+fragLen > c.handMsgLen {
289			return nil, errors.New("dtls: bad fragment length")
290		}
291		c.handMsg = append(c.handMsg, fragment...)
292	}
293	c.recvHandshakeSeq++
294	ret := c.handMsg
295	c.handMsg, c.handMsgLen = nil, 0
296	return ret, nil
297}
298
299// DTLSServer returns a new DTLS server side connection
300// using conn as the underlying transport.
301// The configuration config must be non-nil and must have
302// at least one certificate.
303func DTLSServer(conn net.Conn, config *Config) *Conn {
304	return &Conn{
305		config: config,
306		isDTLS: true,
307		conn:   conn,
308		in:     halfConn{isDTLS: true},
309		out:    halfConn{isDTLS: true},
310	}
311}
312
313// DTLSClient returns a new DTLS client side connection
314// using conn as the underlying transport.
315// The config cannot be nil: users must set either ServerHostname or
316// InsecureSkipVerify in the config.
317func DTLSClient(conn net.Conn, config *Config) *Conn {
318	return &Conn{
319		config:   config,
320		isClient: true,
321		isDTLS:   true,
322		conn:     conn,
323		in:       halfConn{isDTLS: true},
324		out:      halfConn{isDTLS: true},
325	}
326}
327