1// Copyright 2012 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package runner 6 7import ( 8 "crypto/aes" 9 "crypto/cipher" 10 "crypto/hmac" 11 "crypto/sha256" 12 "crypto/subtle" 13 "encoding/binary" 14 "errors" 15 "io" 16 "time" 17) 18 19// sessionState contains the information that is serialized into a session 20// ticket in order to later resume a connection. 21type sessionState struct { 22 vers uint16 23 cipherSuite uint16 24 masterSecret []byte 25 handshakeHash []byte 26 certificates [][]byte 27 extendedMasterSecret bool 28 earlyALPN []byte 29 ticketCreationTime time.Time 30 ticketExpiration time.Time 31 ticketFlags uint32 32 ticketAgeAdd uint32 33} 34 35func (s *sessionState) marshal() []byte { 36 msg := newByteBuilder() 37 msg.addU16(s.vers) 38 msg.addU16(s.cipherSuite) 39 masterSecret := msg.addU16LengthPrefixed() 40 masterSecret.addBytes(s.masterSecret) 41 handshakeHash := msg.addU16LengthPrefixed() 42 handshakeHash.addBytes(s.handshakeHash) 43 msg.addU16(uint16(len(s.certificates))) 44 for _, cert := range s.certificates { 45 certMsg := msg.addU32LengthPrefixed() 46 certMsg.addBytes(cert) 47 } 48 49 if s.extendedMasterSecret { 50 msg.addU8(1) 51 } else { 52 msg.addU8(0) 53 } 54 55 if s.vers >= VersionTLS13 { 56 msg.addU64(uint64(s.ticketCreationTime.UnixNano())) 57 msg.addU64(uint64(s.ticketExpiration.UnixNano())) 58 msg.addU32(s.ticketFlags) 59 msg.addU32(s.ticketAgeAdd) 60 } 61 62 earlyALPN := msg.addU16LengthPrefixed() 63 earlyALPN.addBytes(s.earlyALPN) 64 65 return msg.finish() 66} 67 68func (s *sessionState) unmarshal(data []byte) bool { 69 if len(data) < 8 { 70 return false 71 } 72 73 s.vers = uint16(data[0])<<8 | uint16(data[1]) 74 s.cipherSuite = uint16(data[2])<<8 | uint16(data[3]) 75 masterSecretLen := int(data[4])<<8 | int(data[5]) 76 data = data[6:] 77 if len(data) < masterSecretLen { 78 return false 79 } 80 81 s.masterSecret = data[:masterSecretLen] 82 data = data[masterSecretLen:] 83 84 if len(data) < 2 { 85 return false 86 } 87 88 handshakeHashLen := int(data[0])<<8 | int(data[1]) 89 data = data[2:] 90 if len(data) < handshakeHashLen { 91 return false 92 } 93 94 s.handshakeHash = data[:handshakeHashLen] 95 data = data[handshakeHashLen:] 96 97 if len(data) < 2 { 98 return false 99 } 100 101 numCerts := int(data[0])<<8 | int(data[1]) 102 data = data[2:] 103 104 s.certificates = make([][]byte, numCerts) 105 for i := range s.certificates { 106 if len(data) < 4 { 107 return false 108 } 109 certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3]) 110 data = data[4:] 111 if certLen < 0 { 112 return false 113 } 114 if len(data) < certLen { 115 return false 116 } 117 s.certificates[i] = data[:certLen] 118 data = data[certLen:] 119 } 120 121 if len(data) < 1 { 122 return false 123 } 124 125 s.extendedMasterSecret = false 126 if data[0] == 1 { 127 s.extendedMasterSecret = true 128 } 129 data = data[1:] 130 131 if s.vers >= VersionTLS13 { 132 if len(data) < 24 { 133 return false 134 } 135 s.ticketCreationTime = time.Unix(0, int64(binary.BigEndian.Uint64(data))) 136 data = data[8:] 137 s.ticketExpiration = time.Unix(0, int64(binary.BigEndian.Uint64(data))) 138 data = data[8:] 139 s.ticketFlags = binary.BigEndian.Uint32(data) 140 data = data[4:] 141 s.ticketAgeAdd = binary.BigEndian.Uint32(data) 142 data = data[4:] 143 } 144 145 earlyALPNLen := int(data[0])<<8 | int(data[1]) 146 data = data[2:] 147 if len(data) < earlyALPNLen { 148 return false 149 } 150 s.earlyALPN = data[:earlyALPNLen] 151 data = data[earlyALPNLen:] 152 153 if len(data) > 0 { 154 return false 155 } 156 157 return true 158} 159 160func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) { 161 serialized := state.marshal() 162 encrypted := make([]byte, aes.BlockSize+len(serialized)+sha256.Size) 163 iv := encrypted[:aes.BlockSize] 164 macBytes := encrypted[len(encrypted)-sha256.Size:] 165 166 if _, err := io.ReadFull(c.config.rand(), iv); err != nil { 167 return nil, err 168 } 169 block, err := aes.NewCipher(c.config.SessionTicketKey[:16]) 170 if err != nil { 171 return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) 172 } 173 cipher.NewCTR(block, iv).XORKeyStream(encrypted[aes.BlockSize:], serialized) 174 175 mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32]) 176 mac.Write(encrypted[:len(encrypted)-sha256.Size]) 177 mac.Sum(macBytes[:0]) 178 179 return encrypted, nil 180} 181 182func (c *Conn) decryptTicket(encrypted []byte) (*sessionState, bool) { 183 if len(encrypted) < aes.BlockSize+sha256.Size { 184 return nil, false 185 } 186 187 iv := encrypted[:aes.BlockSize] 188 macBytes := encrypted[len(encrypted)-sha256.Size:] 189 190 mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32]) 191 mac.Write(encrypted[:len(encrypted)-sha256.Size]) 192 expected := mac.Sum(nil) 193 194 if subtle.ConstantTimeCompare(macBytes, expected) != 1 { 195 return nil, false 196 } 197 198 block, err := aes.NewCipher(c.config.SessionTicketKey[:16]) 199 if err != nil { 200 return nil, false 201 } 202 ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size] 203 plaintext := make([]byte, len(ciphertext)) 204 cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext) 205 206 state := new(sessionState) 207 ok := state.unmarshal(plaintext) 208 return state, ok 209} 210