## This file is part of Scapy
## Copyright (C) 2017 Maxence Tury
## This program is published under a GPLv2 license

"""
TLS 1.3 key exchange logic.
"""

import math
import struct

from scapy.config import conf, crypto_validator
from scapy.error import log_runtime, warning
from scapy.fields import *
from scapy.packet import Packet, Raw, Padding
from scapy.layers.tls.cert import PubKeyRSA, PrivKeyRSA
from scapy.layers.tls.session import _GenericTLSSessionInheritance
from scapy.layers.tls.basefields import _tls_version, _TLSClientVersionField
from scapy.layers.tls.extensions import TLS_Ext_Unknown, _tls_ext
from scapy.layers.tls.crypto.pkcs1 import pkcs_i2osp, pkcs_os2ip
from scapy.layers.tls.crypto.groups import (_tls_named_ffdh_groups,
                                            _tls_named_curves, _ffdh_groups,
                                            _tls_named_groups)

if conf.crypto_valid:
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives.asymmetric import dh, ec
if conf.crypto_valid_advanced:
    from cryptography.hazmat.primitives.asymmetric import x25519


def _ffdh_pubkey_len(params):
    """
    Return the byte length of the prime p of the FFDH 'params'.

    FFDH key_exchange values are encoded as big-endian integers left-padded
    with zeros to exactly this many bytes.
    """
    p = params.parameter_numbers().p
    return (p.bit_length() + 7) // 8


def _tls13_compute_dhe_secret(group_name, privkey, pubkey):
    """
    Combine our 'privkey' with the peer's 'pubkey' and return the shared
    secret.

    Both keys must belong to the named group 'group_name'. Return None if
    'group_name' is neither a known FFDH group nor a known curve.
    """
    if group_name in _tls_named_ffdh_groups.values():
        return privkey.exchange(pubkey)
    if group_name in _tls_named_curves.values():
        if group_name == "x25519":
            # X25519 keys are exchanged directly, with no algorithm object.
            return privkey.exchange(pubkey)
        return privkey.exchange(ec.ECDH(), pubkey)
    return None


class KeyShareEntry(Packet):
    """
    When building from scratch, we create a DH private key, and when
    dissecting, we create a DH public key. Default group is secp256r1.

    'key_exchange' holds the encoded public value:
    - for FFDH groups, Y as a big-endian integer left-padded to the byte
      length of the group prime p;
    - for x25519, the raw public bytes;
    - for other curves, the encoded point.
    """
    __slots__ = ["privkey", "pubkey"]
    name = "Key Share Entry"
    fields_desc = [ShortEnumField("group", None, _tls_named_groups),
                   FieldLenField("kxlen", None, length_of="key_exchange"),
                   StrLenField("key_exchange", b"",
                               length_from=lambda pkt: pkt.kxlen)]

    def __init__(self, *args, **kargs):
        # privkey is filled by create_privkey() (building),
        # pubkey by register_pubkey() (dissecting).
        self.privkey = None
        self.pubkey = None
        super(KeyShareEntry, self).__init__(*args, **kargs)

    def do_build(self):
        """
        We need this hack, else 'self' would be replaced by __iter__.next().
        """
        tmp = self.explicit
        self.explicit = True
        try:
            b = super(KeyShareEntry, self).do_build()
        finally:
            # Restore the flag even if building fails.
            self.explicit = tmp
        return b

    @crypto_validator
    def create_privkey(self):
        """
        This is called by post_build() for key creation.

        Generate a private key for self.group, store it in self.privkey and
        store the encoded matching public value in self.key_exchange.
        x448 (and x25519 without advanced crypto support) is skipped.
        """
        if self.group in _tls_named_ffdh_groups:
            params = _ffdh_groups[_tls_named_ffdh_groups[self.group]][0]
            privkey = params.generate_private_key()
            self.privkey = privkey
            pubkey = privkey.public_key()
            y = pubkey.public_numbers().y
            # Encode Y as fixed-length bytes so that it can be measured and
            # concatenated in post_build().
            self.key_exchange = pkcs_i2osp(y, _ffdh_pubkey_len(params))
        elif self.group in _tls_named_curves:
            if _tls_named_curves[self.group] == "x25519":
                if conf.crypto_valid_advanced:
                    privkey = x25519.X25519PrivateKey.generate()
                    self.privkey = privkey
                    pubkey = privkey.public_key()
                    self.key_exchange = pubkey.public_bytes()
            elif _tls_named_curves[self.group] != "x448":
                curve = ec._CURVE_TYPES[_tls_named_curves[self.group]]()
                privkey = ec.generate_private_key(curve, default_backend())
                self.privkey = privkey
                pubkey = privkey.public_key()
                self.key_exchange = pubkey.public_numbers().encode_point()

    def post_build(self, pkt, pay):
        """
        Fill in the default group, a fresh key (when no key_exchange was
        given) and the length, then serialize the entry from its fields.
        """
        if self.group is None:
            self.group = 23  # secp256r1

        if not self.key_exchange:
            try:
                self.create_privkey()
            except ImportError:
                # No usable crypto library: leave key_exchange empty.
                pass

        if self.kxlen is None:
            self.kxlen = len(self.key_exchange)

        group = struct.pack("!H", self.group)
        kxlen = struct.pack("!H", self.kxlen)
        return group + kxlen + self.key_exchange + pay

    @crypto_validator
    def register_pubkey(self):
        """
        Turn the dissected key_exchange value into a public key object and
        store it in self.pubkey. x448 (and x25519 without advanced crypto
        support) is skipped.
        """
        if self.group in _tls_named_ffdh_groups:
            params = _ffdh_groups[_tls_named_ffdh_groups[self.group]][0]
            pn = params.parameter_numbers()
            # DHPublicNumbers expects Y as an integer, not as raw bytes.
            y = pkcs_os2ip(self.key_exchange)
            public_numbers = dh.DHPublicNumbers(y, pn)
            self.pubkey = public_numbers.public_key(default_backend())
        elif self.group in _tls_named_curves:
            if _tls_named_curves[self.group] == "x25519":
                if conf.crypto_valid_advanced:
                    import_point = x25519.X25519PublicKey.from_public_bytes
                    self.pubkey = import_point(self.key_exchange)
            elif _tls_named_curves[self.group] != "x448":
                curve = ec._CURVE_TYPES[_tls_named_curves[self.group]]()
                import_point = ec.EllipticCurvePublicNumbers.from_encoded_point
                public_numbers = import_point(curve, self.key_exchange)
                self.pubkey = public_numbers.public_key(default_backend())

    def post_dissection(self, r):
        try:
            self.register_pubkey()
        except ImportError:
            # No usable crypto library: keep pubkey as None.
            pass

    def extract_padding(self, s):
        # Entries are stacked in a PacketListField: nothing is payload,
        # everything after this entry belongs to the next one.
        return b"", s


class TLS_Ext_KeyShare_CH(TLS_Ext_Unknown):
    """
    Key Share extension as found in a ClientHello: a list of KeyShareEntry.
    Keys are recorded in the TLS session indexed by group name.
    """
    name = "TLS Extension - Key Share (for ClientHello)"
    fields_desc = [ShortEnumField("type", 0x28, _tls_ext),
                   ShortField("len", None),
                   FieldLenField("client_shares_len", None,
                                 length_of="client_shares"),
                   PacketListField("client_shares", [], KeyShareEntry,
                                   length_from=lambda pkt: pkt.client_shares_len)]

    def post_build(self, pkt, pay):
        """
        Save the private key of every client share we generated into
        tls13_client_privshares, indexed by group name.
        """
        if not self.tls_session.frozen:
            privshares = self.tls_session.tls13_client_privshares
            for kse in self.client_shares:
                if kse.privkey:
                    # Index by _tls_named_groups so that FFDH groups work too
                    # and the server side can look the key up by that name.
                    group_name = _tls_named_groups[kse.group]
                    if group_name in privshares:
                        # 'pkt' holds the built bytes here, not a Packet.
                        pkt_info = self.firstlayer().summary()
                        log_runtime.info("TLS: group %s used twice in the same ClientHello [%s]", kse.group, pkt_info)
                        break
                    privshares[group_name] = kse.privkey
        return super(TLS_Ext_KeyShare_CH, self).post_build(pkt, pay)

    def post_dissection(self, r):
        """
        Save the public key of every dissected client share into
        tls13_client_pubshares, indexed by group name.
        """
        if not self.tls_session.frozen:
            pubshares = self.tls_session.tls13_client_pubshares
            for kse in self.client_shares:
                if kse.pubkey:
                    group_name = _tls_named_groups[kse.group]
                    if group_name in pubshares:
                        pkt_info = r.firstlayer().summary()
                        log_runtime.info("TLS: group %s used twice in the same ClientHello [%s]", kse.group, pkt_info)
                        break
                    pubshares[group_name] = kse.pubkey
        return super(TLS_Ext_KeyShare_CH, self).post_dissection(r)


class TLS_Ext_KeyShare_HRR(TLS_Ext_Unknown):
    """
    Key Share extension as found in a HelloRetryRequest: only the group
    selected by the server.
    """
    name = "TLS Extension - Key Share (for HelloRetryRequest)"
    fields_desc = [ShortEnumField("type", 0x28, _tls_ext),
                   ShortField("len", None),
                   ShortEnumField("selected_group", None, _tls_named_groups)]


class TLS_Ext_KeyShare_SH(TLS_Ext_Unknown):
    """
    Key Share extension as found in a ServerHello: a single KeyShareEntry.
    When the matching client share is known, the DHE secret is computed and
    stored in tls13_dhe_secret.
    """
    name = "TLS Extension - Key Share (for ServerHello)"
    fields_desc = [ShortEnumField("type", 0x28, _tls_ext),
                   ShortField("len", None),
                   PacketField("server_share", None, KeyShareEntry)]

    def post_build(self, pkt, pay):
        """
        Save our server private key, then derive the DHE secret from the
        client public key of the same group, if one was recorded.
        """
        share = self.server_share
        if (not self.tls_session.frozen and share is not None and
                share.privkey):
            # if there is a privkey, we assume the crypto library is ok
            privshare = self.tls_session.tls13_server_privshare
            if len(privshare) > 0:
                # 'pkt' holds the built bytes here, not a Packet.
                pkt_info = self.firstlayer().summary()
                log_runtime.info("TLS: overwriting previous server key share [%s]", pkt_info)
            group_name = _tls_named_groups[share.group]
            privshare[group_name] = share.privkey

            client_pubshares = self.tls_session.tls13_client_pubshares
            if group_name in client_pubshares:
                pms = _tls13_compute_dhe_secret(group_name, share.privkey,
                                                client_pubshares[group_name])
                if pms is not None:
                    self.tls_session.tls13_dhe_secret = pms
        return super(TLS_Ext_KeyShare_SH, self).post_build(pkt, pay)

    def post_dissection(self, r):
        """
        Save the dissected server public key, then derive the DHE secret
        from our client private key of the same group, if one was recorded.
        """
        share = self.server_share
        if (not self.tls_session.frozen and share is not None and
                share.pubkey):
            # if there is a pubkey, we assume the crypto library is ok
            pubshare = self.tls_session.tls13_server_pubshare
            if len(pubshare) > 0:
                pkt_info = r.firstlayer().summary()
                log_runtime.info("TLS: overwriting previous server key share [%s]", pkt_info)
            group_name = _tls_named_groups[share.group]
            pubshare[group_name] = share.pubkey

            client_privshares = self.tls_session.tls13_client_privshares
            if group_name in client_privshares:
                pms = _tls13_compute_dhe_secret(group_name,
                                                client_privshares[group_name],
                                                share.pubkey)
                if pms is not None:
                    self.tls_session.tls13_dhe_secret = pms
        return super(TLS_Ext_KeyShare_SH, self).post_dissection(r)


# Key Share extension class to use, keyed by handshake message type
# (1: ClientHello, 2: ServerHello, 6: HelloRetryRequest).
_tls_ext_keyshare_cls = {1: TLS_Ext_KeyShare_CH,
                         2: TLS_Ext_KeyShare_SH,
                         6: TLS_Ext_KeyShare_HRR}


class Ticket(Packet):
    name = "Recommended Ticket Construction (from RFC 5077)"
    fields_desc = [StrFixedLenField("key_name", None, 16),
                   StrFixedLenField("iv", None, 16),
                   FieldLenField("encstatelen", None, length_of="encstate"),
                   StrLenField("encstate", "",
                               length_from=lambda pkt: pkt.encstatelen),
                   StrFixedLenField("mac", None, 32)]


class TicketField(PacketField):
    """
    PacketField holding a Ticket whose length is given by 'length_from'.
    Bytes of the slice that the Ticket does not consume are kept as Padding.
    """
    __slots__ = ["length_from"]

    def __init__(self, name, default, length_from=None, **kargs):
        self.length_from = length_from
        PacketField.__init__(self, name, default, Ticket, **kargs)

    def m2i(self, pkt, m):
        # Without a length callback, the whole remaining input is the ticket.
        tlen = len(m) if self.length_from is None else self.length_from(pkt)
        tbd, rem = m[:tlen], m[tlen:]
        return self.cls(tbd) / Padding(rem)


class PSKIdentity(Packet):
    name = "PSK Identity"
    fields_desc = [FieldLenField("identity_len", None,
                                 length_of="identity"),
                   TicketField("identity", "",
                               length_from=lambda pkt: pkt.identity_len),
                   IntField("obfuscated_ticket_age", 0)]


class PSKBinderEntry(Packet):
    name = "PSK Binder Entry"
    fields_desc = [FieldLenField("binder_len", None, fmt="B",
                                 length_of="binder"),
                   StrLenField("binder", "",
                               length_from=lambda pkt: pkt.binder_len)]


class TLS_Ext_PreSharedKey_CH(TLS_Ext_Unknown):
    # XXX: define post_build and post_dissection methods
    name = "TLS Extension - Pre Shared Key (for ClientHello)"
    # 0x29 is pre_shared_key (0x28 is key_share); same type as the
    # ServerHello variant below.
    fields_desc = [ShortEnumField("type", 0x29, _tls_ext),
                   ShortField("len", None),
                   FieldLenField("identities_len", None,
                                 length_of="identities"),
                   PacketListField("identities", [], PSKIdentity,
                                   length_from=lambda pkt: pkt.identities_len),
                   FieldLenField("binders_len", None,
                                 length_of="binders"),
                   PacketListField("binders", [], PSKBinderEntry,
                                   length_from=lambda pkt: pkt.binders_len)]


class TLS_Ext_PreSharedKey_SH(TLS_Ext_Unknown):
    name = "TLS Extension - Pre Shared Key (for ServerHello)"
    fields_desc = [ShortEnumField("type", 0x29, _tls_ext),
                   ShortField("len", None),
                   ShortField("selected_identity", None)]


# Pre Shared Key extension class to use, keyed by handshake message type
# (1: ClientHello, 2: ServerHello).
_tls_ext_presharedkey_cls = {1: TLS_Ext_PreSharedKey_CH,
                             2: TLS_Ext_PreSharedKey_SH}