• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# SPDX-License-Identifier: GPL-2.0-only
2# This file is part of Scapy
3# See https://scapy.net/ for more information
4# Copyright (C) 2017 Maxence Tury
5#               2019 Romain Perez
6
7"""
8TLS 1.3 key exchange logic.
9"""
10
11import struct
12
13from scapy.config import conf, crypto_validator
14from scapy.error import log_runtime
15from scapy.fields import (
16    FieldLenField,
17    IntField,
18    PacketField,
19    PacketLenField,
20    PacketListField,
21    ShortEnumField,
22    ShortField,
23    StrFixedLenField,
24    StrLenField,
25    XStrLenField,
26)
27from scapy.packet import Packet
28from scapy.layers.tls.extensions import TLS_Ext_Unknown, _tls_ext
29from scapy.layers.tls.cert import PrivKeyECDSA, PrivKeyRSA, PrivKeyEdDSA
30from scapy.layers.tls.crypto.groups import (
31    _tls_named_curves,
32    _tls_named_ffdh_groups,
33    _tls_named_groups,
34    _tls_named_groups_generate,
35    _tls_named_groups_import,
36    _tls_named_groups_pubbytes,
37)
38
39if conf.crypto_valid:
40    from cryptography.hazmat.primitives.asymmetric import ec
41if conf.crypto_valid_advanced:
42    from cryptography.hazmat.primitives.asymmetric import ed25519
43    from cryptography.hazmat.primitives.asymmetric import ed448
44
45
class KeyShareEntry(Packet):
    """
    A single KeyShareEntry (group, key_exchange) of a TLS 1.3 key_share
    extension.

    When building without an explicit key_exchange, a private key is
    generated for the group and its public bytes become key_exchange (the
    private key is kept in ``privkey``). When dissecting, key_exchange is
    imported as a public key into ``pubkey``. Default group is 23
    (secp256r1).
    """
    __slots__ = ["privkey", "pubkey"]
    name = "Key Share Entry"
    fields_desc = [ShortEnumField("group", None, _tls_named_groups),
                   FieldLenField("kxlen", None, length_of="key_exchange"),
                   XStrLenField("key_exchange", "",
                                length_from=lambda pkt: pkt.kxlen)]

    def __init__(self, *args, **kargs):
        # Filled by create_privkey() when building and by register_pubkey()
        # when dissecting.
        self.privkey = None
        self.pubkey = None
        super(KeyShareEntry, self).__init__(*args, **kargs)

    def do_build(self):
        """
        Build with ``explicit`` temporarily forced to True.

        We need this hack, else 'self' would be replaced by __iter__.next().
        The previous value of ``explicit`` is restored even if the build
        raises.
        """
        tmp = self.explicit
        self.explicit = True
        try:
            return super(KeyShareEntry, self).do_build()
        finally:
            self.explicit = tmp

    @crypto_validator
    def create_privkey(self):
        """
        Generate a private key for ``self.group`` and store its public bytes
        in ``key_exchange``.

        This is called by post_build() for key creation.
        """
        self.privkey = _tls_named_groups_generate(self.group)
        self.key_exchange = _tls_named_groups_pubbytes(self.privkey)

    def post_build(self, pkt, pay):
        """
        Fill in defaults and re-serialize the entry from its fields.

        - group defaults to 23 (secp256r1);
        - an empty key_exchange triggers create_privkey() (skipped when it
          raises ImportError, i.e. no usable crypto library);
        - kxlen defaults to the length of key_exchange.

        These defaults are stored on the packet itself (later builds reuse
        them, and the enclosing extension reads ``group``/``privkey``).
        The ``pkt`` argument is not used: the output is rebuilt from the
        updated field values.
        """
        if self.group is None:
            self.group = 23     # secp256r1

        if not self.key_exchange:
            try:
                self.create_privkey()
            except ImportError:
                pass

        if self.kxlen is None:
            self.kxlen = len(self.key_exchange)

        group = struct.pack("!H", self.group)
        kxlen = struct.pack("!H", self.kxlen)
        return group + kxlen + self.key_exchange + pay

    @crypto_validator
    def register_pubkey(self):
        """Import ``key_exchange`` as a public key of ``self.group``."""
        self.pubkey = _tls_named_groups_import(
            self.group,
            self.key_exchange
        )

    def post_dissection(self, r):
        """Import the dissected public key, if the crypto library allows."""
        try:
            self.register_pubkey()
        except ImportError:
            pass

    def extract_padding(self, s):
        """
        An entry carries no payload: all remaining bytes are returned as
        padding so that the enclosing list can dissect the next entry.
        """
        return b"", s
114
class TLS_Ext_KeyShare_CH(TLS_Ext_Unknown):
    """
    key_share extension as carried in a ClientHello (a list of client
    shares).

    On build, the private keys of the shares are recorded in the session
    (``tls13_client_privshares``); on dissection, the imported public keys
    are recorded instead (``tls13_client_pubshares``). Both are keyed by
    group name.
    """
    name = "TLS Extension - Key Share (for ClientHello)"
    fields_desc = [ShortEnumField("type", 0x33, _tls_ext),
                   ShortField("len", None),
                   FieldLenField("client_shares_len", None,
                                 length_of="client_shares"),
                   PacketListField("client_shares", [], KeyShareEntry,
                                   length_from=lambda pkt: pkt.client_shares_len)]  # noqa: E501

    def post_build(self, pkt, pay):
        if not self.tls_session.frozen:
            privshares = self.tls_session.tls13_client_privshares
            for kse in self.client_shares:
                if not kse.privkey:
                    continue
                group_name = _tls_named_groups[kse.group]
                if group_name in privshares:
                    # 'pkt' is the raw bytes built so far, not a Packet, so
                    # the packet summary is taken from this layer instead.
                    pkt_info = self.firstlayer().summary()
                    log_runtime.info("TLS: group %s used twice in the same ClientHello [%s]", kse.group, pkt_info)  # noqa: E501
                    break
                privshares[group_name] = kse.privkey
        return super(TLS_Ext_KeyShare_CH, self).post_build(pkt, pay)

    def post_dissection(self, r):
        if not self.tls_session.frozen:
            pubshares = self.tls_session.tls13_client_pubshares
            for kse in self.client_shares:
                if not kse.pubkey:
                    continue
                group_name = _tls_named_groups[kse.group]
                if group_name in pubshares:
                    pkt_info = r.firstlayer().summary()
                    log_runtime.info("TLS: group %s used twice in the same ClientHello [%s]", kse.group, pkt_info)  # noqa: E501
                    break
                pubshares[group_name] = kse.pubkey
        return super(TLS_Ext_KeyShare_CH, self).post_dissection(r)
147
148
class TLS_Ext_KeyShare_HRR(TLS_Ext_Unknown):
    """
    key_share extension as carried in a HelloRetryRequest: it only names
    the selected group, without any key material.
    """
    name = "TLS Extension - Key Share (for HelloRetryRequest)"
    fields_desc = [
        ShortEnumField("type", 0x33, _tls_ext),
        ShortField("len", None),
        ShortEnumField("selected_group", None, _tls_named_groups),
    ]
154
155
def _tls13_dhe_exchange(privkey, pubkey, group_name):
    """
    Compute the TLS 1.3 (EC)DHE shared secret for ``group_name``.

    :param privkey: our private key object for the group
    :param pubkey: the peer's public key object for the same group
    :param group_name: a group name, as found in _tls_named_groups
    :return: the shared secret, or None when group_name is neither one of
        _tls_named_ffdh_groups nor one of _tls_named_curves
    """
    if group_name in _tls_named_ffdh_groups.values():
        return privkey.exchange(pubkey)
    if group_name in _tls_named_curves.values():
        if group_name in ["x25519", "x448"]:
            # X25519/X448 private keys take the peer key directly.
            return privkey.exchange(pubkey)
        # Other curves go through the ECDH algorithm object.
        return privkey.exchange(ec.ECDH(), pubkey)
    return None


class TLS_Ext_KeyShare_SH(TLS_Ext_Unknown):
    """
    key_share extension as carried in a ServerHello (a single server share).

    The server share is recorded in the session, and as soon as a private
    share and a matching public share are known for its group, the shared
    secret is stored as ``tls13_dhe_secret`` along with ``kx_group``.
    """
    name = "TLS Extension - Key Share (for ServerHello)"
    fields_desc = [ShortEnumField("type", 0x33, _tls_ext),
                   ShortField("len", None),
                   PacketField("server_share", None, KeyShareEntry)]

    def _set_dhe_secret(self, privkey, pubkey, group_name):
        """Derive the shared secret for group_name and store it in the
        session; log and leave the session untouched if unsupported."""
        pms = _tls13_dhe_exchange(privkey, pubkey, group_name)
        if pms is None:
            log_runtime.info("TLS: cannot compute the shared secret for group %s", group_name)  # noqa: E501
            return
        self.tls_session.tls13_dhe_secret = pms
        self.tls_session.kx_group = group_name

    def post_build(self, pkt, pay):
        share = self.server_share
        # server_share may be unset (None); getattr keeps that case a no-op.
        if not self.tls_session.frozen and getattr(share, "privkey", None):
            # if there is a privkey, we assume the crypto library is ok
            session = self.tls_session
            privshare = session.tls13_server_privshare
            if privshare:
                # 'pkt' is raw bytes here, not a Packet: summarize from
                # this layer instead.
                pkt_info = self.firstlayer().summary()
                log_runtime.info("TLS: overwriting previous server key share [%s]", pkt_info)  # noqa: E501
            group_name = _tls_named_groups[share.group]
            privshare[group_name] = share.privkey

            client_pubshares = session.tls13_client_pubshares
            if group_name in client_pubshares:
                self._set_dhe_secret(share.privkey,
                                     client_pubshares[group_name],
                                     group_name)
        return super(TLS_Ext_KeyShare_SH, self).post_build(pkt, pay)

    def post_dissection(self, r):
        share = self.server_share
        pubkey = getattr(share, "pubkey", None)
        if not self.tls_session.frozen and pubkey:
            # if there is a pubkey, we assume the crypto library is ok
            session = self.tls_session
            pubshare = session.tls13_server_pubshare
            if pubshare:
                pkt_info = r.firstlayer().summary()
                log_runtime.info("TLS: overwriting previous server key share [%s]", pkt_info)  # noqa: E501
            group_name = _tls_named_groups[share.group]
            pubshare[group_name] = pubkey

            if group_name in session.tls13_client_privshares:
                # Client side: our private share + the server public share.
                self._set_dhe_secret(
                    session.tls13_client_privshares[group_name],
                    pubkey,
                    group_name,
                )
            elif (group_name in session.tls13_server_privshare and
                  group_name in session.tls13_client_pubshares):
                # The session holds the server private share: combine it
                # with the client public share for the same group.
                self._set_dhe_secret(
                    session.tls13_server_privshare[group_name],
                    session.tls13_client_pubshares[group_name],
                    group_name,
                )
        return super(TLS_Ext_KeyShare_SH, self).post_dissection(r)
221
222
# key_share extension classes, keyed by what appears to be the handshake
# message type (1 for the ClientHello class, 2 for the ServerHello class).
_tls_ext_keyshare_cls = {
    1: TLS_Ext_KeyShare_CH,
    2: TLS_Ext_KeyShare_SH,
}

# HelloRetryRequest variant, registered under the same key as ServerHello.
_tls_ext_keyshare_hrr_cls = {
    2: TLS_Ext_KeyShare_HRR,
}
227
228
class Ticket(Packet):
    """
    Session ticket in the layout recommended by RFC 5077: 16-byte key name,
    16-byte IV, length-prefixed encrypted state, 32-byte MAC.
    """
    name = "Recommended Ticket Construction (from RFC 5077)"
    fields_desc = [
        StrFixedLenField("key_name", None, 16),
        StrFixedLenField("iv", None, 16),
        FieldLenField("encstatelen", None, length_of="encstate"),
        StrLenField("encstate", "",
                    length_from=lambda p: p.encstatelen),
        StrFixedLenField("mac", None, 32),
    ]
237
238
class TicketField(PacketLenField):
    """
    PacketLenField that dissects a PSK identity as a Ticket only when it is
    long enough to hold one; shorter identities are kept as raw bytes.
    """
    # Smallest possible Ticket: key_name (16) + iv (16) + encstatelen (2)
    # + empty encstate (0) + mac (32). Shorter data cannot be a Ticket.
    _MIN_TICKET_LEN = 16 + 16 + 2 + 32

    def m2i(self, pkt, m):
        if len(m) < self._MIN_TICKET_LEN:
            return conf.raw_layer(m)
        return self.cls(m)
245
246
class PSKIdentity(Packet):
    """
    One PSK identity: a length-prefixed identity (dissected as a Ticket
    when large enough) followed by the obfuscated ticket age.
    """
    name = "PSK Identity"
    fields_desc = [
        FieldLenField("identity_len", None, length_of="identity"),
        TicketField("identity", "", Ticket,
                    length_from=lambda p: p.identity_len),
        IntField("obfuscated_ticket_age", 0),
    ]

    def default_payload_class(self, payload):
        # Trailing bytes are padding (i.e. the next list element).
        return conf.padding_layer
257
258
class PSKBinderEntry(Packet):
    """One PSK binder, prefixed by a one-byte length."""
    name = "PSK Binder Entry"
    fields_desc = [
        FieldLenField("binder_len", None, fmt="B", length_of="binder"),
        StrLenField("binder", "",
                    length_from=lambda p: p.binder_len),
    ]

    def default_payload_class(self, payload):
        # Trailing bytes are padding (i.e. the next list element).
        return conf.padding_layer
268
269
class TLS_Ext_PreSharedKey_CH(TLS_Ext_Unknown):
    """
    pre_shared_key extension as carried in a ClientHello: a list of PSK
    identities followed by a list of binders.
    """
    # TODO: define post_build and post_dissection methods (this class only
    # describes the wire format; no session state is updated).
    name = "TLS Extension - Pre Shared Key (for ClientHello)"
    fields_desc = [
        ShortEnumField("type", 0x29, _tls_ext),
        ShortField("len", None),
        FieldLenField("identities_len", None, length_of="identities"),
        PacketListField("identities", [], PSKIdentity,
                        length_from=lambda p: p.identities_len),
        FieldLenField("binders_len", None, length_of="binders"),
        PacketListField("binders", [], PSKBinderEntry,
                        length_from=lambda p: p.binders_len),
    ]
283
284
class TLS_Ext_PreSharedKey_SH(TLS_Ext_Unknown):
    """
    pre_shared_key extension as carried in a ServerHello: the index of the
    identity selected by the server.
    """
    name = "TLS Extension - Pre Shared Key (for ServerHello)"
    fields_desc = [
        ShortEnumField("type", 0x29, _tls_ext),
        ShortField("len", None),
        ShortField("selected_identity", None),
    ]
290
291
# pre_shared_key extension classes, keyed by what appears to be the
# handshake message type (1 for the ClientHello class, 2 for ServerHello).
_tls_ext_presharedkey_cls = {
    1: TLS_Ext_PreSharedKey_CH,
    2: TLS_Ext_PreSharedKey_SH,
}
294
295
296# Util to find usable signature algorithms
297
298# TLS 1.3 SignatureScheme is a subset of _tls_hash_sig
299_tls13_usable_certificate_verify_algs = [
300    # ECDSA algorithms
301    0x0403, 0x0503, 0x0603,
302    # RSASSA-PSS algorithms with public key OID rsaEncryption
303    0x0804, 0x0805, 0x0806,
304    # EdDSA algorithms
305    0x0807, 0x0808,
306]
307
308_tls13_usable_certificate_signature_algs = [
309    # RSASSA-PKCS1-v1_5 algorithms
310    0x0401, 0x0501, 0x0601,
311    # ECDSA algorithms
312    0x0403, 0x0503, 0x0603,
313    # EdDSA algorithms
314    0x0807, 0x0808,
315    # RSASSA-PSS algorithms with public key OID RSASSA-PSS
316    0x0809, 0x080a, 0x080b,
317    # Legacy algorithms
318    0x0201, 0x0203,
319]
320
321
def get_usable_tls13_sigalgs(li, key, location="certificateverify"):
    """
    Select, among proposed TLS 1.3 signature algorithms, those usable with
    ``key``.

    :param li: proposed SignatureScheme values, in proposal order
    :param key: a PrivKeyRSA, PrivKeyECDSA or PrivKeyEdDSA; any other
        object yields an empty list
    :param location: "certificateverify" or "certificatesignature", which
        picks the table of acceptable schemes; any other value yields an
        empty list
    :return: a new list of the usable values, in the order of ``li``
    """
    from scapy.layers.tls.keyexchange import _tls_hash_sig

    def _family_of(k):
        # Signature family name produced by the key, or None if the key
        # type is not supported at all.
        if isinstance(k, PrivKeyRSA):
            return "rsa"
        if isinstance(k, PrivKeyECDSA):
            return "ecdsa"
        if not isinstance(k, PrivKeyEdDSA):
            return None
        if isinstance(k.pubkey, ed25519.Ed25519PublicKey):
            return "ed25519"
        if isinstance(k.pubkey, ed448.Ed448PublicKey):
            return "ed448"
        return "unknown"

    def _sig_part(scheme):
        # _tls_hash_sig entries are either "<hash>+<sig>" or just "<sig>".
        description = _tls_hash_sig[scheme]
        if "+" not in description:
            return description
        _, sig = description.split('+')
        return sig

    family = _family_of(key)
    if family is None:
        return []

    if location == "certificateverify":
        table = _tls13_usable_certificate_verify_algs
    elif location == "certificatesignature":
        table = _tls13_usable_certificate_signature_algs
    else:
        return []

    # The family is matched as a substring of the signature part.
    return [scheme for scheme in li
            if scheme in table and family in _sig_part(scheme)]
360