1// Copyright 2014 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5// DTLS implementation. 6// 7// NOTE: This is a not even a remotely production-quality DTLS 8// implementation. It is the bare minimum necessary to be able to 9// achieve coverage on BoringSSL's implementation. Of note is that 10// this implementation assumes the underlying net.PacketConn is not 11// only reliable but also ordered. BoringSSL will be expected to deal 12// with simulated loss, but there is no point in forcing the test 13// driver to. 14 15package runner 16 17import ( 18 "bytes" 19 "errors" 20 "fmt" 21 "io" 22 "math/rand" 23 "net" 24) 25 26func versionToWire(vers uint16, isDTLS bool) uint16 { 27 if isDTLS { 28 return ^(vers - 0x0201) 29 } 30 return vers 31} 32 33func wireToVersion(vers uint16, isDTLS bool) uint16 { 34 if isDTLS { 35 return ^vers + 0x0201 36 } 37 return vers 38} 39 40func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) { 41 recordHeaderLen := dtlsRecordHeaderLen 42 43 if c.rawInput == nil { 44 c.rawInput = c.in.newBlock() 45 } 46 b := c.rawInput 47 48 // Read a new packet only if the current one is empty. 49 if len(b.data) == 0 { 50 // Pick some absurdly large buffer size. 51 b.resize(maxCiphertext + recordHeaderLen) 52 n, err := c.conn.Read(c.rawInput.data) 53 if err != nil { 54 return 0, nil, err 55 } 56 if c.config.Bugs.MaxPacketLength != 0 && n > c.config.Bugs.MaxPacketLength { 57 return 0, nil, fmt.Errorf("dtls: exceeded maximum packet length") 58 } 59 c.rawInput.resize(n) 60 } 61 62 // Read out one record. 63 // 64 // A real DTLS implementation should be tolerant of errors, 65 // but this is test code. We should not be tolerant of our 66 // peer sending garbage. 67 if len(b.data) < recordHeaderLen { 68 return 0, nil, errors.New("dtls: failed to read record header") 69 } 70 typ := recordType(b.data[0]) 71 vers := wireToVersion(uint16(b.data[1])<<8|uint16(b.data[2]), c.isDTLS) 72 if c.haveVers { 73 if vers != c.vers { 74 c.sendAlert(alertProtocolVersion) 75 return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.vers)) 76 } 77 } else { 78 if expect := c.config.Bugs.ExpectInitialRecordVersion; expect != 0 && vers != expect { 79 c.sendAlert(alertProtocolVersion) 80 return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect)) 81 } 82 } 83 epoch := b.data[3:5] 84 seq := b.data[5:11] 85 // For test purposes, require the sequence number be monotonically 86 // increasing, so c.in includes the minimum next sequence number. Gaps 87 // may occur if packets failed to be sent out. A real implementation 88 // would maintain a replay window and such. 89 if !bytes.Equal(epoch, c.in.seq[:2]) { 90 c.sendAlert(alertIllegalParameter) 91 return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch")) 92 } 93 if bytes.Compare(seq, c.in.seq[2:]) < 0 { 94 c.sendAlert(alertIllegalParameter) 95 return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number")) 96 } 97 copy(c.in.seq[2:], seq) 98 n := int(b.data[11])<<8 | int(b.data[12]) 99 if n > maxCiphertext || len(b.data) < recordHeaderLen+n { 100 c.sendAlert(alertRecordOverflow) 101 return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: oversized record received with length %d", n)) 102 } 103 104 // Process message. 105 b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n) 106 ok, off, err := c.in.decrypt(b) 107 if !ok { 108 c.in.setErrorLocked(c.sendAlert(err)) 109 } 110 b.off = off 111 return typ, b, nil 112} 113 114func (c *Conn) makeFragment(header, data []byte, fragOffset, fragLen int) []byte { 115 fragment := make([]byte, 0, 12+fragLen) 116 fragment = append(fragment, header...) 117 fragment = append(fragment, byte(c.sendHandshakeSeq>>8), byte(c.sendHandshakeSeq)) 118 fragment = append(fragment, byte(fragOffset>>16), byte(fragOffset>>8), byte(fragOffset)) 119 fragment = append(fragment, byte(fragLen>>16), byte(fragLen>>8), byte(fragLen)) 120 fragment = append(fragment, data[fragOffset:fragOffset+fragLen]...) 121 return fragment 122} 123 124func (c *Conn) dtlsWriteRecord(typ recordType, data []byte) (n int, err error) { 125 if typ != recordTypeHandshake { 126 // Only handshake messages are fragmented. 127 return c.dtlsWriteRawRecord(typ, data) 128 } 129 130 maxLen := c.config.Bugs.MaxHandshakeRecordLength 131 if maxLen <= 0 { 132 maxLen = 1024 133 } 134 135 // Handshake messages have to be modified to include fragment 136 // offset and length and with the header replicated. Save the 137 // TLS header here. 138 // 139 // TODO(davidben): This assumes that data contains exactly one 140 // handshake message. This is incompatible with 141 // FragmentAcrossChangeCipherSpec. (Which is unfortunate 142 // because OpenSSL's DTLS implementation will probably accept 143 // such fragmentation and could do with a fix + tests.) 144 header := data[:4] 145 data = data[4:] 146 147 isFinished := header[0] == typeFinished 148 149 if c.config.Bugs.SendEmptyFragments { 150 fragment := c.makeFragment(header, data, 0, 0) 151 c.pendingFragments = append(c.pendingFragments, fragment) 152 } 153 154 firstRun := true 155 fragOffset := 0 156 for firstRun || fragOffset < len(data) { 157 firstRun = false 158 fragLen := len(data) - fragOffset 159 if fragLen > maxLen { 160 fragLen = maxLen 161 } 162 163 fragment := c.makeFragment(header, data, fragOffset, fragLen) 164 if c.config.Bugs.FragmentMessageTypeMismatch && fragOffset > 0 { 165 fragment[0]++ 166 } 167 if c.config.Bugs.FragmentMessageLengthMismatch && fragOffset > 0 { 168 fragment[3]++ 169 } 170 171 // Buffer the fragment for later. They will be sent (and 172 // reordered) on flush. 173 c.pendingFragments = append(c.pendingFragments, fragment) 174 if c.config.Bugs.ReorderHandshakeFragments { 175 // Don't duplicate Finished to avoid the peer 176 // interpreting it as a retransmit request. 177 if !isFinished { 178 c.pendingFragments = append(c.pendingFragments, fragment) 179 } 180 181 if fragLen > (maxLen+1)/2 { 182 // Overlap each fragment by half. 183 fragLen = (maxLen + 1) / 2 184 } 185 } 186 fragOffset += fragLen 187 n += fragLen 188 } 189 if !isFinished && c.config.Bugs.MixCompleteMessageWithFragments { 190 fragment := c.makeFragment(header, data, 0, len(data)) 191 c.pendingFragments = append(c.pendingFragments, fragment) 192 } 193 194 // Increment the handshake sequence number for the next 195 // handshake message. 196 c.sendHandshakeSeq++ 197 return 198} 199 200func (c *Conn) dtlsFlushHandshake() error { 201 if !c.isDTLS { 202 return nil 203 } 204 205 // This is a test-only DTLS implementation, so there is no need to 206 // retain |c.pendingFragments| for a future retransmit. 207 var fragments [][]byte 208 fragments, c.pendingFragments = c.pendingFragments, fragments 209 210 if c.config.Bugs.ReorderHandshakeFragments { 211 perm := rand.New(rand.NewSource(0)).Perm(len(fragments)) 212 tmp := make([][]byte, len(fragments)) 213 for i := range tmp { 214 tmp[i] = fragments[perm[i]] 215 } 216 fragments = tmp 217 } 218 219 maxRecordLen := c.config.Bugs.PackHandshakeFragments 220 maxPacketLen := c.config.Bugs.PackHandshakeRecords 221 222 // Pack handshake fragments into records. 223 var records [][]byte 224 for _, fragment := range fragments { 225 if n := c.config.Bugs.SplitFragments; n > 0 { 226 if len(fragment) > n { 227 records = append(records, fragment[:n]) 228 records = append(records, fragment[n:]) 229 } else { 230 records = append(records, fragment) 231 } 232 } else if i := len(records) - 1; len(records) > 0 && len(records[i])+len(fragment) <= maxRecordLen { 233 records[i] = append(records[i], fragment...) 234 } else { 235 // The fragment will be appended to, so copy it. 236 records = append(records, append([]byte{}, fragment...)) 237 } 238 } 239 240 // Format them into packets. 241 var packets [][]byte 242 for _, record := range records { 243 b, err := c.dtlsSealRecord(recordTypeHandshake, record) 244 if err != nil { 245 return err 246 } 247 248 if i := len(packets) - 1; len(packets) > 0 && len(packets[i])+len(b.data) <= maxPacketLen { 249 packets[i] = append(packets[i], b.data...) 250 } else { 251 // The sealed record will be appended to and reused by 252 // |c.out|, so copy it. 253 packets = append(packets, append([]byte{}, b.data...)) 254 } 255 c.out.freeBlock(b) 256 } 257 258 // Send all the packets. 259 for _, packet := range packets { 260 if _, err := c.conn.Write(packet); err != nil { 261 return err 262 } 263 } 264 return nil 265} 266 267// dtlsSealRecord seals a record into a block from |c.out|'s pool. 268func (c *Conn) dtlsSealRecord(typ recordType, data []byte) (b *block, err error) { 269 recordHeaderLen := dtlsRecordHeaderLen 270 maxLen := c.config.Bugs.MaxHandshakeRecordLength 271 if maxLen <= 0 { 272 maxLen = 1024 273 } 274 275 b = c.out.newBlock() 276 277 explicitIVLen := 0 278 explicitIVIsSeq := false 279 280 if cbc, ok := c.out.cipher.(cbcMode); ok { 281 // Block cipher modes have an explicit IV. 282 explicitIVLen = cbc.BlockSize() 283 } else if aead, ok := c.out.cipher.(*tlsAead); ok { 284 if aead.explicitNonce { 285 explicitIVLen = 8 286 // The AES-GCM construction in TLS has an explicit nonce so that 287 // the nonce can be random. However, the nonce is only 8 bytes 288 // which is too small for a secure, random nonce. Therefore we 289 // use the sequence number as the nonce. 290 explicitIVIsSeq = true 291 } 292 } else if c.out.cipher != nil { 293 panic("Unknown cipher") 294 } 295 b.resize(recordHeaderLen + explicitIVLen + len(data)) 296 b.data[0] = byte(typ) 297 vers := c.vers 298 if vers == 0 { 299 // Some TLS servers fail if the record version is greater than 300 // TLS 1.0 for the initial ClientHello. 301 vers = VersionTLS10 302 } 303 vers = versionToWire(vers, c.isDTLS) 304 b.data[1] = byte(vers >> 8) 305 b.data[2] = byte(vers) 306 // DTLS records include an explicit sequence number. 307 copy(b.data[3:11], c.out.outSeq[0:]) 308 b.data[11] = byte(len(data) >> 8) 309 b.data[12] = byte(len(data)) 310 if explicitIVLen > 0 { 311 explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen] 312 if explicitIVIsSeq { 313 copy(explicitIV, c.out.outSeq[:]) 314 } else { 315 if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil { 316 return 317 } 318 } 319 } 320 copy(b.data[recordHeaderLen+explicitIVLen:], data) 321 c.out.encrypt(b, explicitIVLen) 322 return 323} 324 325func (c *Conn) dtlsWriteRawRecord(typ recordType, data []byte) (n int, err error) { 326 b, err := c.dtlsSealRecord(typ, data) 327 if err != nil { 328 return 329 } 330 331 _, err = c.conn.Write(b.data) 332 if err != nil { 333 return 334 } 335 n = len(data) 336 337 c.out.freeBlock(b) 338 339 if typ == recordTypeChangeCipherSpec { 340 err = c.out.changeCipherSpec(c.config) 341 if err != nil { 342 // Cannot call sendAlert directly, 343 // because we already hold c.out.Mutex. 344 c.tmp[0] = alertLevelError 345 c.tmp[1] = byte(err.(alert)) 346 c.writeRecord(recordTypeAlert, c.tmp[0:2]) 347 return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) 348 } 349 } 350 return 351} 352 353func (c *Conn) dtlsDoReadHandshake() ([]byte, error) { 354 // Assemble a full handshake message. For test purposes, this 355 // implementation assumes fragments arrive in order. It may 356 // need to be cleverer if we ever test BoringSSL's retransmit 357 // behavior. 358 for len(c.handMsg) < 4+c.handMsgLen { 359 // Get a new handshake record if the previous has been 360 // exhausted. 361 if c.hand.Len() == 0 { 362 if err := c.in.err; err != nil { 363 return nil, err 364 } 365 if err := c.readRecord(recordTypeHandshake); err != nil { 366 return nil, err 367 } 368 } 369 370 // Read the next fragment. It must fit entirely within 371 // the record. 372 if c.hand.Len() < 12 { 373 return nil, errors.New("dtls: bad handshake record") 374 } 375 header := c.hand.Next(12) 376 fragN := int(header[1])<<16 | int(header[2])<<8 | int(header[3]) 377 fragSeq := uint16(header[4])<<8 | uint16(header[5]) 378 fragOff := int(header[6])<<16 | int(header[7])<<8 | int(header[8]) 379 fragLen := int(header[9])<<16 | int(header[10])<<8 | int(header[11]) 380 381 if c.hand.Len() < fragLen { 382 return nil, errors.New("dtls: fragment length too long") 383 } 384 fragment := c.hand.Next(fragLen) 385 386 // Check it's a fragment for the right message. 387 if fragSeq != c.recvHandshakeSeq { 388 return nil, errors.New("dtls: bad handshake sequence number") 389 } 390 391 // Check that the length is consistent. 392 if c.handMsg == nil { 393 c.handMsgLen = fragN 394 if c.handMsgLen > maxHandshake { 395 return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError)) 396 } 397 // Start with the TLS handshake header, 398 // without the DTLS bits. 399 c.handMsg = append([]byte{}, header[:4]...) 400 } else if fragN != c.handMsgLen { 401 return nil, errors.New("dtls: bad handshake length") 402 } 403 404 // Add the fragment to the pending message. 405 if 4+fragOff != len(c.handMsg) { 406 return nil, errors.New("dtls: bad fragment offset") 407 } 408 if fragOff+fragLen > c.handMsgLen { 409 return nil, errors.New("dtls: bad fragment length") 410 } 411 c.handMsg = append(c.handMsg, fragment...) 412 } 413 c.recvHandshakeSeq++ 414 ret := c.handMsg 415 c.handMsg, c.handMsgLen = nil, 0 416 return ret, nil 417} 418 419// DTLSServer returns a new DTLS server side connection 420// using conn as the underlying transport. 421// The configuration config must be non-nil and must have 422// at least one certificate. 423func DTLSServer(conn net.Conn, config *Config) *Conn { 424 c := &Conn{config: config, isDTLS: true, conn: conn} 425 c.init() 426 return c 427} 428 429// DTLSClient returns a new DTLS client side connection 430// using conn as the underlying transport. 431// The config cannot be nil: users must set either ServerHostname or 432// InsecureSkipVerify in the config. 433func DTLSClient(conn net.Conn, config *Config) *Conn { 434 c := &Conn{config: config, isClient: true, isDTLS: true, conn: conn} 435 c.init() 436 return c 437} 438