diff --git a/third_party/tlslite/tlslite/TLSConnection.py b/third_party/tlslite/tlslite/TLSConnection.py index f8811a9..e882e2c 100644 --- a/third_party/tlslite/tlslite/TLSConnection.py +++ b/third_party/tlslite/tlslite/TLSConnection.py @@ -611,6 +611,8 @@ class TLSConnection(TLSRecordLayer): settings.cipherImplementations) #Exchange ChangeCipherSpec and Finished messages + for result in self._getChangeCipherSpec(): + yield result for result in self._getFinished(): yield result for result in self._sendFinished(): @@ -920,6 +922,8 @@ class TLSConnection(TLSRecordLayer): #Exchange ChangeCipherSpec and Finished messages for result in self._sendFinished(): yield result + for result in self._getChangeCipherSpec(): + yield result for result in self._getFinished(): yield result @@ -1089,6 +1093,7 @@ class TLSConnection(TLSRecordLayer): clientCertChain = None serverCertChain = None #We may set certChain to this later postFinishedError = None + doingChannelID = False #Tentatively set version to most-desirable version, so if an error #occurs parsing the ClientHello, this is what we'll use for the @@ -1208,6 +1213,8 @@ class TLSConnection(TLSRecordLayer): serverHello.create(self.version, serverRandom, session.sessionID, session.cipherSuite, certificateType) + serverHello.channel_id = clientHello.channel_id + doingChannelID = clientHello.channel_id for result in self._sendMsg(serverHello): yield result @@ -1221,6 +1228,11 @@ class TLSConnection(TLSRecordLayer): #Exchange ChangeCipherSpec and Finished messages for result in self._sendFinished(): yield result + for result in self._getChangeCipherSpec(): + yield result + if doingChannelID: + for result in self._getEncryptedExtensions(): + yield result for result in self._getFinished(): yield result @@ -1399,8 +1411,12 @@ class TLSConnection(TLSRecordLayer): #Send ServerHello, Certificate[, CertificateRequest], #ServerHelloDone msgs = [] - msgs.append(ServerHello().create(self.version, serverRandom, - sessionID, cipherSuite, certificateType)) + serverHello = ServerHello().create( + self.version, serverRandom, + sessionID, cipherSuite, certificateType) + serverHello.channel_id = clientHello.channel_id + doingChannelID = clientHello.channel_id + msgs.append(serverHello) msgs.append(Certificate(certificateType).create(serverCertChain)) if reqCert and reqCAs: msgs.append(CertificateRequest().create([], reqCAs)) @@ -1528,6 +1544,11 @@ class TLSConnection(TLSRecordLayer): settings.cipherImplementations) #Exchange ChangeCipherSpec and Finished messages + for result in self._getChangeCipherSpec(): + yield result + if doingChannelID: + for result in self._getEncryptedExtensions(): + yield result for result in self._getFinished(): yield result diff --git a/third_party/tlslite/tlslite/TLSRecordLayer.py b/third_party/tlslite/tlslite/TLSRecordLayer.py index 1bbd09d..933b95a 100644 --- a/third_party/tlslite/tlslite/TLSRecordLayer.py +++ b/third_party/tlslite/tlslite/TLSRecordLayer.py @@ -714,6 +714,8 @@ class TLSRecordLayer: self.version).parse(p) elif subType == HandshakeType.finished: yield Finished(self.version).parse(p) + elif subType == HandshakeType.encrypted_extensions: + yield EncryptedExtensions().parse(p) else: raise AssertionError() @@ -1067,7 +1069,7 @@ class TLSRecordLayer: for result in self._sendMsg(finished): yield result - def _getFinished(self): + def _getChangeCipherSpec(self): #Get and check ChangeCipherSpec for result in self._getMsg(ContentType.change_cipher_spec): if result in (0,1): @@ -1082,6 +1084,15 @@ class TLSRecordLayer: #Switch to pending read state self._changeReadState() + def _getEncryptedExtensions(self): + for result in self._getMsg(ContentType.handshake, + HandshakeType.encrypted_extensions): + if result in (0,1): + yield result + encrypted_extensions = result + self.channel_id = encrypted_extensions.channel_id_key + + def _getFinished(self): #Calculate verification data verifyData = self._calcFinished(False) diff --git a/third_party/tlslite/tlslite/constants.py b/third_party/tlslite/tlslite/constants.py index 04302c0..e357dd0 100644 --- a/third_party/tlslite/tlslite/constants.py +++ b/third_party/tlslite/tlslite/constants.py @@ -22,6 +22,7 @@ class HandshakeType: certificate_verify = 15 client_key_exchange = 16 finished = 20 + encrypted_extensions = 203 class ContentType: change_cipher_spec = 20 @@ -30,6 +31,9 @@ class ContentType: application_data = 23 all = (20,21,22,23) +class ExtensionType: + channel_id = 30031 + class AlertLevel: warning = 1 fatal = 2 diff --git a/third_party/tlslite/tlslite/messages.py b/third_party/tlslite/tlslite/messages.py index dc6ed32..fa4d817 100644 --- a/third_party/tlslite/tlslite/messages.py +++ b/third_party/tlslite/tlslite/messages.py @@ -130,6 +130,7 @@ class ClientHello(HandshakeMsg): self.certificate_types = [CertificateType.x509] self.compression_methods = [] # a list of 8-bit values self.srp_username = None # a string + self.channel_id = False def create(self, version, random, session_id, cipher_suites, certificate_types=None, srp_username=None): @@ -174,6 +175,8 @@ class ClientHello(HandshakeMsg): self.srp_username = bytesToString(p.getVarBytes(1)) elif extType == 7: self.certificate_types = p.getVarList(1, 1) + elif extType == ExtensionType.channel_id: + self.channel_id = True else: p.getFixBytes(extLength) soFar += 4 + extLength @@ -220,6 +223,7 @@ class ServerHello(HandshakeMsg): self.cipher_suite = 0 self.certificate_type = CertificateType.x509 self.compression_method = 0 + self.channel_id = False def create(self, version, random, session_id, cipher_suite, certificate_type): @@ -266,6 +270,9 @@ class ServerHello(HandshakeMsg): CertificateType.x509: extLength += 5 + if self.channel_id: + extLength += 4 + if extLength != 0: w.add(extLength, 2) @@ -275,6 +282,10 @@ class ServerHello(HandshakeMsg): w.add(1, 2) w.add(self.certificate_type, 1) + if self.channel_id: + w.add(ExtensionType.channel_id, 2) + w.add(0, 2) + return HandshakeMsg.postWrite(self, w, trial) class Certificate(HandshakeMsg): @@ -567,6 +578,28 @@ class Finished(HandshakeMsg): w.addFixSeq(self.verify_data, 1) return HandshakeMsg.postWrite(self, w, trial) +class EncryptedExtensions(HandshakeMsg): + def __init__(self): + self.channel_id_key = None + self.channel_id_proof = None + + def parse(self, p): + p.startLengthCheck(3) + soFar = 0 + while soFar != p.lengthCheck: + extType = p.get(2) + extLength = p.get(2) + if extType == ExtensionType.channel_id: + if extLength != 32*4: + raise SyntaxError() + self.channel_id_key = p.getFixBytes(64) + self.channel_id_proof = p.getFixBytes(64) + else: + p.getFixBytes(extLength) + soFar += 4 + extLength + p.stopLengthCheck() + return self + class ApplicationData(Msg): def __init__(self): self.contentType = ContentType.application_data