• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# SPDX-License-Identifier: GPL-2.0-only
2# This file is part of Scapy
3# See https://scapy.net/ for more information
4# Copyright (C) 2007, 2008, 2009 Arnaud Ebalard
5#               2015, 2016, 2017 Maxence Tury
6#               2019 Romain Perez
7
8"""
9TLS key exchange logic.
10"""
11
12import math
13import struct
14
15from scapy.config import conf, crypto_validator
16from scapy.error import warning
17from scapy.fields import ByteEnumField, ByteField, EnumField, FieldLenField, \
18    FieldListField, PacketField, ShortEnumField, ShortField, \
19    StrFixedLenField, StrLenField
20from scapy.compat import orb
21from scapy.packet import Packet, Raw, Padding
22from scapy.layers.tls.cert import PubKeyRSA, PrivKeyRSA
23from scapy.layers.tls.session import _GenericTLSSessionInheritance
24from scapy.layers.tls.basefields import _tls_version, _TLSClientVersionField
25from scapy.layers.tls.crypto.pkcs1 import pkcs_i2osp, pkcs_os2ip
26from scapy.layers.tls.crypto.groups import (
27    _ffdh_groups,
28    _tls_named_curves,
29    _tls_named_groups_generate,
30    _tls_named_groups_import,
31    _tls_named_groups_pubbytes,
32)
33
34
35if conf.crypto_valid:
36    from cryptography.hazmat.backends import default_backend
37    from cryptography.hazmat.primitives.asymmetric import dh, ec
38    from cryptography.hazmat.primitives import serialization
39if conf.crypto_valid_advanced:
40    from cryptography.hazmat.primitives.asymmetric import x25519
41    from cryptography.hazmat.primitives.asymmetric import x448
42
43
44###############################################################################
45#   Common Fields                                                             #
46###############################################################################
47
48_tls_hash_sig = {0x0000: "none+anon", 0x0001: "none+rsa",
49                 0x0002: "none+dsa", 0x0003: "none+ecdsa",
50                 0x0100: "md5+anon", 0x0101: "md5+rsa",
51                 0x0102: "md5+dsa", 0x0103: "md5+ecdsa",
52                 0x0200: "sha1+anon", 0x0201: "sha1+rsa",
53                 0x0202: "sha1+dsa", 0x0203: "sha1+ecdsa",
54                 0x0300: "sha224+anon", 0x0301: "sha224+rsa",
55                 0x0302: "sha224+dsa", 0x0303: "sha224+ecdsa",
56                 0x0400: "sha256+anon", 0x0401: "sha256+rsa",
57                 0x0402: "sha256+dsa", 0x0403: "sha256+ecdsa",
58                 0x0500: "sha384+anon", 0x0501: "sha384+rsa",
59                 0x0502: "sha384+dsa", 0x0503: "sha384+ecdsa",
60                 0x0600: "sha512+anon", 0x0601: "sha512+rsa",
61                 0x0602: "sha512+dsa", 0x0603: "sha512+ecdsa",
62                 0x0804: "sha256+rsaepss", 0x0805: "sha384+rsaepss",
63                 0x0806: "sha512+rsaepss", 0x0807: "ed25519",
64                 0x0808: "ed448", 0x0809: "sha256+rsapss",
65                 0x080a: "sha384+rsapss", 0x080b: "sha512+rsapss"}
66
67
def phantom_mode(pkt):
    """
    Tell whether version-dependent signature fields must be skipped.

    Without a TLS session, or while no tls_version has been set (no complete
    ClientHello processed yet, e.g. when reading/building a standalone
    signature_algorithms extension), phantom mode is off. Once a version is
    known, phantom mode is on for anything strictly older than TLS 1.2
    (0x0303), which has no explicit SignatureAndHashAlgorithm fields.
    """
    session = pkt.tls_session
    if session and session.tls_version:
        return session.tls_version < 0x0303
    return False
80
81
def phantom_decorate(f, get_or_add):
    """
    Decorator for version-dependent fields.

    Wrap the field method 'f' (a getfield or addfield) so that it is
    bypassed when phantom_mode(pkt) is true:
    - if get_or_add is True (means get), the wrapper returns
      (s, self.phantom_value) without consuming any bytes;
    - if it is False (means add), the wrapper returns s unchanged.
    Otherwise 'f' is called with the very same arguments.

    The wrapper carries f's name/docstring (functools.wraps) so that the
    decorated field methods remain introspectable.
    """
    import functools

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Field methods are invoked as (self, pkt, s[, val]).
        self, pkt, s = args[:3]
        if phantom_mode(pkt):
            if get_or_add:
                return s, self.phantom_value
            return s
        return f(*args, **kwargs)
    return wrapper
96
97
class SigAndHashAlgField(EnumField):
    """
    The 'sig_alg' field of _TLSSignature.

    In phantom mode (see phantom_mode()) the field is absent from the wire:
    dissection consumes nothing and yields phantom_value (None), and
    building emits nothing.
    """
    addfield = phantom_decorate(EnumField.addfield, False)
    getfield = phantom_decorate(EnumField.getfield, True)
    phantom_value = None
103
104
class SigAndHashAlgsLenField(FieldLenField):
    """
    Length of the signature algorithms list, used in
    TLS_Ext_SignatureAlgorithms and TLSCertificateRequest.

    In phantom mode (see phantom_mode()) the length is absent from the wire:
    nothing is read (the value is phantom_value, 0) and nothing is written.
    """
    addfield = phantom_decorate(FieldLenField.addfield, False)
    getfield = phantom_decorate(FieldLenField.getfield, True)
    phantom_value = 0
110
111
class SigAndHashAlgsField(FieldListField):
    """
    Signature algorithms list, used in TLS_Ext_SignatureAlgorithms and
    TLSCertificateRequest.

    In phantom mode (see phantom_mode()) the list is absent from the wire:
    nothing is written, and dissection yields an empty list without
    consuming any bytes.
    """
    # Template for the phantom-mode value. getfield() returns a copy of it,
    # never this shared class-level list itself, so mutating one packet's
    # list cannot leak into later dissections.
    phantom_value = []
    addfield = phantom_decorate(FieldListField.addfield, False)

    def getfield(self, pkt, s):
        if phantom_mode(pkt):
            return s, list(self.phantom_value)
        return FieldListField.getfield(self, pkt, s)
117
118
class SigLenField(FieldLenField):
    """
    Length of 'sig_val'. There is a trick for SSLv2, which uses implicit
    lengths: for any tls_version below 0x0300, nothing is read (the value
    is None) and nothing is written.
    """

    @staticmethod
    def _is_sslv2(pkt):
        """True when the session's tls_version is set and below SSLv3."""
        version = pkt.tls_session.tls_version
        return bool(version) and version < 0x0300

    def getfield(self, pkt, s):
        if not self._is_sslv2(pkt):
            return super(SigLenField, self).getfield(pkt, s)
        return s, None

    def addfield(self, pkt, s, val):
        """With SSLv2 you will never be able to add a sig_len."""
        if not self._is_sslv2(pkt):
            return super(SigLenField, self).addfield(pkt, s, val)
        return s
134
135
class SigValField(StrLenField):
    """
    Signature bytes. There is a trick for SSLv2, which uses implicit
    lengths: with tls_version below 0x0300 the size is taken from the first
    client certificate's key size (in bytes), or guessed as 256 bytes when
    no client certificate is known.
    """

    def getfield(self, pkt, m):
        session = pkt.tls_session
        version = session.tls_version
        if not (version and version < 0x0300):
            return super(SigValField, self).getfield(pkt, m)

        if len(session.client_certs) > 0:
            first_cert = session.client_certs[0]
            sig_len = first_cert.pubKey.pubkey.key_size // 8
        else:
            warning("No client certificate provided. "
                    "We're making a wild guess about the signature size.")
            sig_len = 256
        remain = m[sig_len:]
        return remain, self.m2i(pkt, m[:sig_len])
150
151
class _TLSSignature(_GenericTLSSessionInheritance):
    """
    Prior to TLS 1.2, digitally-signed structure implicitly used the
    concatenation of a MD5 hash and a SHA-1 hash.
    Then TLS 1.2 introduced explicit SignatureAndHashAlgorithms,
    i.e. couples of (hash_alg, sig_alg). See RFC 5246, section 7.4.1.4.1.

    By default, the _TLSSignature implements the TLS 1.2 scheme,
    but if it is provided a TLS context with a tls_version < 0x0303
    at initialization, it will fall back to the implicit signature.
    Even more, the 'sig_len' field won't be used with SSLv2.
    """
    name = "TLS Digital Signature"
    fields_desc = [SigAndHashAlgField("sig_alg", None, _tls_hash_sig),
                   SigLenField("sig_len", None, fmt="!H",
                               length_of="sig_val"),
                   SigValField("sig_val", None,
                               length_from=lambda pkt: pkt.sig_len)]

    def __init__(self, *args, **kargs):
        super(_TLSSignature, self).__init__(*args, **kargs)
        if "sig_alg" not in kargs:
            # Default sig_alg: 0x0804 (sha256+rsaepss), i.e. RSA-PSS. This
            # is also what TLS 1.3 signatures use.
            self.sig_alg = 0x0804
            s = self.tls_session
            if s and s.tls_version:
                if s.selected_sig_alg:
                    self.sig_alg = s.selected_sig_alg
                elif s.tls_version < 0x0303:
                    # Implicit (pre-TLS 1.2) signature: no explicit sig_alg.
                    self.sig_alg = None

    def post_dissection(self, r):
        # for client
        self.tls_session.selected_sig_alg = self.sig_alg

    def _sig_hash_and_type(self):
        """
        Return the (h, t) pair passed to key.sign() / cert.verify() for the
        explicit 'sig_alg'.

        ed25519/ed448 (0x0807/0x0808) have no separate hash: the whole label
        is used as h and t is None. Other labels are "hash+sig"; t is "pss"
        when sig ends with "pss", and "pkcs" otherwise.
        Raises KeyError if 'sig_alg' is not in _tls_hash_sig.
        """
        if self.sig_alg in (0x0807, 0x0808):  # ed25519, ed448
            return _tls_hash_sig[self.sig_alg], None
        h, sig = _tls_hash_sig[self.sig_alg].split('+')
        return h, ("pss" if sig.endswith('pss') else "pkcs")

    def _update_sig(self, m, key):
        """
        Sign 'm' with the PrivKey 'key' and update our own 'sig_val'.
        Note that, even when 'sig_alg' is not None, we use the signature scheme
        of the PrivKey (neither do we care to compare the both of them).
        """
        if self.sig_alg is None:
            if self.tls_session.tls_version >= 0x0300:
                self.sig_val = key.sign(m, t='pkcs', h='md5-sha1')
            else:
                self.sig_val = key.sign(m, t='pkcs', h='md5')
        else:
            h, t = self._sig_hash_and_type()
            self.sig_val = key.sign(m, t=t, h=h)

    def _verify_sig(self, m, cert):
        """
        Verify that our own 'sig_val' carries the signature of 'm' by the
        key associated to the Cert 'cert'. Return False when there is no
        'sig_val' at all.
        """
        if not self.sig_val:
            return False
        if self.sig_alg:
            h, t = self._sig_hash_and_type()
            return cert.verify(m, self.sig_val, t=t, h=h)
        # NOTE(review): unlike _update_sig (which tests 'is None'), this test
        # is on truthiness, so sig_alg == 0x0000 ("none+anon") is verified
        # with the implicit MD5/SHA-1 scheme here.
        if self.tls_session.tls_version >= 0x0300:
            return cert.verify(m, self.sig_val, t='pkcs', h='md5-sha1')
        return cert.verify(m, self.sig_val, t='pkcs', h='md5')

    def guess_payload_class(self, p):
        return Padding
238
239
class _TLSSignatureField(PacketField):
    """
    Used for 'digitally-signed struct' in several ServerKeyExchange,
    and also in CertificateVerify. We can handle the anonymous case:
    when length_from(pkt) is 0 there is no signature and the value is None.
    """
    __slots__ = ["length_from"]

    def __init__(self, name, default, length_from=None):
        self.length_from = length_from
        PacketField.__init__(self, name, default, _TLSSignature)

    def m2i(self, pkt, m):
        if self.length_from(pkt) == 0:
            return None
        return _TLSSignature(m, tls_session=pkt.tls_session)

    def getfield(self, pkt, s):
        sig = self.m2i(pkt, s)
        if sig is None:
            return s, None
        if conf.padding_layer not in sig:
            return b"", sig
        # Detach the trailing Padding and hand its bytes back as the remain.
        pad = sig[conf.padding_layer]
        del pad.underlayer.payload
        return pad.load, sig
267
268
class _TLSServerParamsField(PacketField):
    """
    This is a dispatcher for the Server*DHParams below, used in
    TLSServerKeyExchange and based on the key_exchange.server_kx_msg_cls.
    When this cls is None, it means that we should not see a ServerKeyExchange,
    so we grab everything within length_from and make it available using Raw.

    When the context has not been set (e.g. when no ServerHello was parsed or
    dissected beforehand), we (kinda) clumsily set the cls by trial and error:
    FFDH first, then ECDH, and Raw if neither looks plausible.
    XXX We could use Serv*DHParams.check_params() once it has been implemented.
    """
    __slots__ = ["length_from"]

    def __init__(self, name, default, length_from=None):
        self.length_from = length_from
        PacketField.__init__(self, name, default, None)

    @staticmethod
    def _followed_by_sig_alg(params):
        """
        Guessing heuristic: accept 'params' only if the bytes left after it
        (kept as Padding) start with a value listed in _tls_hash_sig.
        """
        try:
            rest = params.load
        except AttributeError:
            # Nothing was left after the params.
            return False
        if not rest:
            return False
        return pkcs_os2ip(rest[:2]) in _tls_hash_sig

    def m2i(self, pkt, m):
        s = pkt.tls_session
        tmp_len = self.length_from(pkt)
        if s.prcs:
            cls = s.prcs.key_exchange.server_kx_msg_cls(m)
            if cls is None:
                return Raw(m[:tmp_len]) / Padding(m[tmp_len:])
            return cls(m, tls_session=s)

        # No negotiated context: trial and error. A failure to parse as FFDH
        # is expected here and simply means "try ECDH next".
        try:
            p = ServerDHParams(m, tls_session=s)
            if self._followed_by_sig_alg(p):
                return p
        except Exception:
            pass

        # _tls_server_ecdh_cls_guess() returns None for an empty 'm' or an
        # unknown curve type; fall back to Raw instead of calling None.
        cls = _tls_server_ecdh_cls_guess(m)
        if cls is not None:
            p = cls(m, tls_session=s)
            if self._followed_by_sig_alg(p):
                return p
        return Raw(m[:tmp_len]) / Padding(m[tmp_len:])
306
307
308###############################################################################
309#   Server Key Exchange parameters & value                                    #
310###############################################################################
311
312# Finite Field Diffie-Hellman
313
class ServerDHParams(_GenericTLSSessionInheritance):
    """
    ServerDHParams for FFDH-based key exchanges, as defined in RFC 5246/7.4.3.

    Either with .fill_missing() or .post_dissection(), the server_kx_privkey or
    server_kx_pubkey of the TLS context are updated according to the
    parsed/assembled values. It is the user's responsibility to store and
    restore the original values if they want to keep them. For instance, this
    could be done between the writing of a ServerKeyExchange and the receiving
    of a ClientKeyExchange (which includes secret generation).
    """
    name = "Server FFDH parameters"
    fields_desc = [FieldLenField("dh_plen", None, length_of="dh_p"),
                   StrLenField("dh_p", "",
                               length_from=lambda pkt: pkt.dh_plen),
                   FieldLenField("dh_glen", None, length_of="dh_g"),
                   StrLenField("dh_g", "",
                               length_from=lambda pkt: pkt.dh_glen),
                   FieldLenField("dh_Yslen", None, length_of="dh_Ys"),
                   StrLenField("dh_Ys", "",
                               length_from=lambda pkt: pkt.dh_Yslen)]

    @crypto_validator
    def fill_missing(self):
        """
        We do not want TLSServerKeyExchange.build() to overload and recompute
        things every time it is called. This method can be called specifically
        to have things filled in a smart fashion.

        Empty p/g default to the 'modp2048' group. When dh_Ys is empty, a new
        server_kx_privkey is generated and stored in the session; otherwise
        the user is expected to have set server_kx_privkey. Every length left
        to None is derived from the corresponding value.
        Note that we do not expect default_params.g to be more than 0xff.
        """
        s = self.tls_session

        default_params = _ffdh_groups['modp2048'][0].parameter_numbers()
        default_mLen = _ffdh_groups['modp2048'][1]

        if not self.dh_p:
            self.dh_p = pkcs_i2osp(default_params.p, default_mLen // 8)
        if self.dh_plen is None:
            self.dh_plen = len(self.dh_p)
        s.kx_group = "ffdhe%s" % (self.dh_plen * 8)

        if not self.dh_g:
            self.dh_g = pkcs_i2osp(default_params.g, 1)
        if self.dh_glen is None:
            # Derived from the actual bytes, like dh_plen and dh_Yslen: a
            # user-supplied generator may span more than one byte.
            self.dh_glen = len(self.dh_g)

        p = pkcs_os2ip(self.dh_p)
        g = pkcs_os2ip(self.dh_g)
        real_params = dh.DHParameterNumbers(p, g).parameters(default_backend())

        if not self.dh_Ys:
            s.server_kx_privkey = real_params.generate_private_key()
            pubkey = s.server_kx_privkey.public_key()
            y = pubkey.public_numbers().y
            self.dh_Ys = pkcs_i2osp(y, pubkey.key_size // 8)
        # else, we assume that the user wrote the server_kx_privkey by himself
        if self.dh_Yslen is None:
            self.dh_Yslen = len(self.dh_Ys)

        if not s.client_kx_ffdh_params:
            s.client_kx_ffdh_params = real_params

    @crypto_validator
    def register_pubkey(self):
        """
        Store the server public key built from (p, g, Ys) in the session,
        and use (p, g) as client FFDH params if none are set yet.
        XXX Check that the pubkey received is in the group.
        """
        p = pkcs_os2ip(self.dh_p)
        g = pkcs_os2ip(self.dh_g)
        pn = dh.DHParameterNumbers(p, g)

        y = pkcs_os2ip(self.dh_Ys)
        public_numbers = dh.DHPublicNumbers(y, pn)

        s = self.tls_session
        s.server_kx_pubkey = public_numbers.public_key(default_backend())
        s.kx_group = "ffdhe%s" % (self.dh_plen * 8)

        if not s.client_kx_ffdh_params:
            s.client_kx_ffdh_params = pn.parameters(default_backend())

    def post_dissection(self, r):
        try:
            self.register_pubkey()
        except ImportError:
            pass

    def guess_payload_class(self, p):
        """
        The signature after the params gets saved as Padding.
        This way, the .getfield() which _TLSServerParamsField inherits
        from PacketField will return the signature remain as expected.
        """
        return Padding
409
410
411# Elliptic Curve Diffie-Hellman
412
413_tls_ec_curve_types = {1: "explicit_prime",
414                       2: "explicit_char2",
415                       3: "named_curve"}
416
417_tls_ec_basis_types = {0: "ec_basis_trinomial", 1: "ec_basis_pentanomial"}
418
419
class ECCurvePkt(Packet):
    """
    The 'curve' member of explicit ECParameters: values a and b, each
    prefixed by a one-byte length.
    """
    name = "Elliptic Curve"
    fields_desc = [FieldLenField("alen", None, length_of="a", fmt="B"),
                   StrLenField("a", "", length_from=lambda pkt: pkt.alen),
                   FieldLenField("blen", None, length_of="b", fmt="B"),
                   StrLenField("b", "", length_from=lambda pkt: pkt.blen)]

    def guess_payload_class(self, p):
        """
        Keep the bytes that follow the curve as Padding, so that the
        PacketField holding this packet returns them to the enclosing
        Server*Params (base, order, ...) instead of swallowing them into a
        Raw payload, as ECTrinomialBasis/ECPentanomialBasis already do.
        """
        return Padding
426
427
428# Char2 Curves
429
class ECTrinomialBasis(Packet):
    """Trinomial basis of a char2 curve: a single length-prefixed 'k'."""
    name = "EC Trinomial Basis"
    fields_desc = [
        FieldLenField("klen", None, length_of="k", fmt="B"),
        StrLenField("k", "", length_from=lambda pkt: pkt.klen),
    ]
    # basis_type code, reported through _ECBasisField.i2basis_type()
    val = 0

    def guess_payload_class(self, p):
        """Leave the trailing bytes to the enclosing params, as Padding."""
        return Padding
438
439
class ECPentanomialBasis(Packet):
    """
    Pentanomial basis of a char2 curve: three length-prefixed values
    k1, k2 and k3.
    """
    name = "EC Pentanomial Basis"
    fields_desc = [
        FieldLenField("k1len", None, length_of="k1", fmt="B"),
        StrLenField("k1", "", length_from=lambda pkt: pkt.k1len),
        FieldLenField("k2len", None, length_of="k2", fmt="B"),
        StrLenField("k2", "", length_from=lambda pkt: pkt.k2len),
        FieldLenField("k3len", None, length_of="k3", fmt="B"),
        StrLenField("k3", "", length_from=lambda pkt: pkt.k3len),
    ]
    # basis_type code, reported through _ECBasisField.i2basis_type()
    val = 1

    def guess_payload_class(self, p):
        """Leave the trailing bytes to the enclosing params, as Padding."""
        return Padding
452
453
# basis_type code -> basis packet class, keyed by each class's own 'val'.
_tls_ec_basis_cls = {basis.val: basis
                     for basis in (ECTrinomialBasis, ECPentanomialBasis)}
455
456
class _ECBasisTypeField(ByteEnumField):
    """
    One-byte basis_type which, when left to None, is derived at build time
    from the value of the field named 'basis_type_of' (through that
    field's i2basis_type()).
    """
    __slots__ = ["basis_type_of"]

    def __init__(self, name, default, enum, basis_type_of, remain=0):
        # 'remain' is accepted for signature compatibility; it is unused.
        self.basis_type_of = basis_type_of
        EnumField.__init__(self, name, default, enum, "B")

    def i2m(self, pkt, x):
        if x is not None:
            return x
        basis_fld, basis_val = pkt.getfield_and_val(self.basis_type_of)
        return basis_fld.i2basis_type(pkt, basis_val)
469
470
class _ECBasisField(PacketField):
    """
    Basis packet whose class is picked in 'clsdict' from the basis type
    returned by 'basis_type_from(pkt)'.
    """
    __slots__ = ["clsdict", "basis_type_from"]

    def __init__(self, name, default, basis_type_from, clsdict):
        self.clsdict = clsdict
        self.basis_type_from = basis_type_from
        PacketField.__init__(self, name, default, None)

    def m2i(self, pkt, m):
        basis_cls = self.clsdict[self.basis_type_from(pkt)]
        return basis_cls(m)

    def i2basis_type(self, pkt, x):
        """Return x.val, or 0 when x has no usable 'val' (e.g. None)."""
        try:
            return x.val
        except Exception:
            return 0
491
492
493# Distinct ECParameters
494##
# To support the different ECParameters structures defined in Sect. 5.4 of
# RFC 4492, we define 3 separate classes implementing the 3 associated
# ServerECDHParams: ServerECDHNamedCurveParams, ServerECDHExplicitPrimeParams
# and ServerECDHExplicitChar2Params (support for this one is only partial).
# By far the most common of the 3 is ServerECDHNamedCurveParams.
500
class ServerECDHExplicitPrimeParams(_GenericTLSSessionInheritance):
    """
    We provide parsing abilities for ExplicitPrimeParams, but there is no
    support from the cryptography library, hence no context operations.
    """
    name = "Server ECDH parameters - Explicit Prime"
    fields_desc = [ByteEnumField("curve_type", 1, _tls_ec_curve_types),
                   FieldLenField("plen", None, length_of="p", fmt="B"),
                   StrLenField("p", "", length_from=lambda pkt: pkt.plen),
                   PacketField("curve", None, ECCurvePkt),
                   FieldLenField("baselen", None, length_of="base", fmt="B"),
                   StrLenField("base", "",
                               length_from=lambda pkt: pkt.baselen),
                   FieldLenField("orderlen", None,
                                 length_of="order", fmt="B"),
                   StrLenField("order", "",
                               length_from=lambda pkt: pkt.orderlen),
                   FieldLenField("cofactorlen", None,
                                 length_of="cofactor", fmt="B"),
                   StrLenField("cofactor", "",
                               length_from=lambda pkt: pkt.cofactorlen),
                   FieldLenField("pointlen", None,
                                 length_of="point", fmt="B"),
                   StrLenField("point", "",
                               length_from=lambda pkt: pkt.pointlen)]

    def fill_missing(self):
        """
        Only restore curve_type when it was reset to None; every other field
        (including the cofactor) is left as the user set it.
        """
        if self.curve_type is None:
            # _tls_ec_curve_types maps code -> name, so it cannot be indexed
            # by "explicit_prime"; use the code directly.
            self.curve_type = 1  # explicit_prime

    def guess_payload_class(self, p):
        return Padding
537
538
class ServerECDHExplicitChar2Params(_GenericTLSSessionInheritance):
    """
    We provide parsing abilities for Char2Params, but there is no
    support from the cryptography library, hence no context operations.
    """
    name = "Server ECDH parameters - Explicit Char2"
    fields_desc = [ByteEnumField("curve_type", 2, _tls_ec_curve_types),
                   ShortField("m", None),
                   _ECBasisTypeField("basis_type", None,
                                     _tls_ec_basis_types, "basis"),
                   _ECBasisField("basis", ECTrinomialBasis(),
                                 lambda pkt: pkt.basis_type,
                                 _tls_ec_basis_cls),
                   PacketField("curve", ECCurvePkt(), ECCurvePkt),
                   FieldLenField("baselen", None, length_of="base", fmt="B"),
                   StrLenField("base", "",
                               length_from=lambda pkt: pkt.baselen),
                   ByteField("order", None),
                   ByteField("cofactor", None),
                   FieldLenField("pointlen", None,
                                 length_of="point", fmt="B"),
                   StrLenField("point", "",
                               length_from=lambda pkt: pkt.pointlen)]

    def fill_missing(self):
        """Only restore curve_type when it was reset to None."""
        if self.curve_type is None:
            # _tls_ec_curve_types maps code -> name, so it cannot be indexed
            # by "explicit_char2"; use the code directly.
            self.curve_type = 2  # explicit_char2

    def guess_payload_class(self, p):
        return Padding
569
570
class ServerECDHNamedCurveParams(_GenericTLSSessionInheritance):
    name = "Server ECDH parameters - Named Curve"
    fields_desc = [ByteEnumField("curve_type", 3, _tls_ec_curve_types),
                   ShortEnumField("named_curve", None, _tls_named_curves),
                   FieldLenField("pointlen", None,
                                 length_of="point", fmt="B"),
                   StrLenField("point", None,
                               length_from=lambda pkt: pkt.pointlen)]

    @crypto_validator
    def fill_missing(self):
        """
        We do not want TLSServerKeyExchange.build() to overload and recompute
        things every time it is called. This method can be called specifically
        to have things filled in a smart fashion.

        The named curve defaults to 23 (secp256r1), which is also used as a
        fallback for key generation when named_curve is unknown. A new
        server_kx_privkey is generated only when 'point' is empty, so that a
        user-provided point/private key pair is never overwritten, and so
        that repeated calls keep the key and the advertised point in sync.

        XXX We should account for the point_format (before 'point' filling).
        """
        s = self.tls_session

        if self.curve_type is None:
            # _tls_ec_curve_types maps code -> name, so it cannot be indexed
            # by "named_curve"; use the code directly.
            self.curve_type = 3  # named_curve

        if self.named_curve is None:
            self.named_curve = 23

        curve_group = self.named_curve
        if curve_group not in _tls_named_curves:
            # this fallback is arguable
            curve_group = 23  # default to secp256r1
        s.kx_group = _tls_named_curves.get(curve_group, str(curve_group))

        if self.point is None:
            s.server_kx_privkey = _tls_named_groups_generate(curve_group)
            self.point = _tls_named_groups_pubbytes(
                s.server_kx_privkey
            )
        # else, we assume that the user wrote the server_kx_privkey by himself

        if self.pointlen is None:
            self.pointlen = len(self.point)

        if not s.client_kx_ecdh_params:
            s.client_kx_ecdh_params = curve_group

    @crypto_validator
    def register_pubkey(self):
        """
        Import 'point' on 'named_curve' as the session's server_kx_pubkey,
        and use the curve as client ECDH params if none are set yet.
        XXX Support compressed point format.
        XXX Check that the pubkey received is on the curve.
        """
        # point_format = 0
        # if self.point[0] in [b'\x02', b'\x03']:
        #    point_format = 1

        s = self.tls_session
        s.server_kx_pubkey = _tls_named_groups_import(
            self.named_curve,
            self.point
        )
        s.kx_group = _tls_named_curves.get(self.named_curve, str(self.named_curve))

        if not s.client_kx_ecdh_params:
            s.client_kx_ecdh_params = self.named_curve

    def post_dissection(self, r):
        try:
            self.register_pubkey()
        except ImportError:
            pass

    def guess_payload_class(self, p):
        return Padding
644
645
# ECCurveType code (1, 2, 3) -> matching ServerECDH*Params class.
_tls_server_ecdh_cls = dict(enumerate((ServerECDHExplicitPrimeParams,
                                       ServerECDHExplicitChar2Params,
                                       ServerECDHNamedCurveParams), start=1))
649
650
def _tls_server_ecdh_cls_guess(m):
    """
    Return the ServerECDH*Params class selected by the first byte of 'm'
    (the curve type), or None when 'm' is empty or the type is unknown.
    """
    return _tls_server_ecdh_cls.get(orb(m[0])) if m else None
656
657
658# RSA Encryption (export)
659
class ServerRSAParams(_GenericTLSSessionInheritance):
    """
    Defined for RSA_EXPORT kx : it enables servers to share RSA keys shorter
    than their principal {>512}-bit key, when it is not allowed for kx.

    This should not appear in standard RSA kx negotiation, as the key
    has already been advertised in the Certificate message.
    """
    name = "Server RSA_EXPORT parameters"
    fields_desc = [FieldLenField("rsamodlen", None, length_of="rsamod"),
                   StrLenField("rsamod", "",
                               length_from=lambda pkt: pkt.rsamodlen),
                   FieldLenField("rsaexplen", None, length_of="rsaexp"),
                   StrLenField("rsaexp", "",
                               length_from=lambda pkt: pkt.rsaexplen)]

    @crypto_validator
    def fill_missing(self):
        """
        Generate a new temporary RSA key (modulusLen=512), store it as the
        session's server_tmp_rsa_key, and fill the modulus/exponent fields
        and lengths that were left empty.
        """
        k = PrivKeyRSA()
        k.fill_and_store(modulusLen=512)
        self.tls_session.server_tmp_rsa_key = k
        pubNum = k.pubkey.public_numbers()

        if not self.rsamod:
            self.rsamod = pkcs_i2osp(pubNum.n, k.pubkey.key_size // 8)
        if self.rsamodlen is None:
            self.rsamodlen = len(self.rsamod)

        # NOTE(review): rsamodlen counts bytes, so this reads e.g. "rsa64"
        # for a 512-bit key, whereas ServerDHParams reports bits.
        self.tls_session.kx_group = "rsa%s" % self.rsamodlen

        # Minimal big-endian byte length of e, computed with exact integer
        # arithmetic rather than a floating-point logarithm (which is
        # imprecise for large values and gives 0 for e == 1).
        rsaexplen = (pubNum.e.bit_length() + 7) // 8
        if not self.rsaexp:
            self.rsaexp = pkcs_i2osp(pubNum.e, rsaexplen)
        if self.rsaexplen is None:
            self.rsaexplen = len(self.rsaexp)

    @crypto_validator
    def register_pubkey(self):
        """Store the advertised (e, n) as the session's server_tmp_rsa_key."""
        mLen = self.rsamodlen
        m = self.rsamod
        e = self.rsaexp
        self.tls_session.server_tmp_rsa_key = PubKeyRSA((e, m, mLen))
        self.tls_session.kx_group = "rsa%s" % mLen

    def post_dissection(self, pkt):
        try:
            self.register_pubkey()
        except ImportError:
            pass

    def guess_payload_class(self, p):
        return Padding
712
713
714# Pre-Shared Key
715
class ServerPSKParams(Packet):
    """
    XXX We provide some parsing abilities for ServerPSKParams, but the
    context operations have not been implemented yet. See RFC 4279.
    Note that we do not cover the (EC)DHE_PSK key exchange,
    which should contain a Server*DHParams after 'psk_identity_hint'.
    """
    name = "Server PSK parameters"
    fields_desc = [
        FieldLenField("psk_identity_hint_len", None,
                      length_of="psk_identity_hint", fmt="!H"),
        StrLenField("psk_identity_hint", "",
                    length_from=lambda pkt: pkt.psk_identity_hint_len),
    ]

    def guess_payload_class(self, p):
        return Padding

    def fill_missing(self):
        """No-op: there is nothing to compute for PSK params yet."""

    def post_dissection(self, pkt):
        """No-op: PSK context operations are not implemented."""
737
738
739###############################################################################
740#   Client Key Exchange value                                                 #
741###############################################################################
742
743# FFDH/ECDH
744
class ClientDiffieHellmanPublic(_GenericTLSSessionInheritance):
    """
    If the user provides a value for dh_Yc attribute, we assume he will set
    the pms and ms accordingly and trigger the key derivation on his own.

    XXX As specified in 7.4.7.2. of RFC 4346, we should distinguish the needs
    for implicit or explicit value depending on availability of DH parameters
    in *client* certificate. For now we can only do ephemeral/explicit DH.
    """
    name = "Client DH Public Value"
    fields_desc = [FieldLenField("dh_Yclen", None, length_of="dh_Yc"),
                   StrLenField("dh_Yc", "",
                               length_from=lambda pkt: pkt.dh_Yclen)]

    @crypto_validator
    def fill_missing(self):
        """
        Generate the client key pair from the session's client_kx_ffdh_params
        and set dh_Yc to the public value. If the server public key is known,
        also set the pre-master secret (leading zero bytes stripped) and,
        unless extms is set, derive the master secret and keys.
        """
        session = self.tls_session
        session.client_kx_privkey = \
            session.client_kx_ffdh_params.generate_private_key()
        pub = session.client_kx_privkey.public_key()
        self.dh_Yc = pkcs_i2osp(pub.public_numbers().y, pub.key_size // 8)

        if not (session.client_kx_privkey and session.server_kx_pubkey):
            return
        shared = session.client_kx_privkey.exchange(session.server_kx_pubkey)
        session.pre_master_secret = shared.lstrip(b"\x00")
        # With extended master secret (extms), the key needs the session
        # hash, which TLSClientKeyExchange provides: derivation happens there.
        if not session.extms:
            session.compute_ms_and_derive_keys()

    def post_build(self, pkt, pay):
        """
        Rebuild the wire bytes from dh_Yclen/dh_Yc, first generating dh_Yc
        when it is empty. An ImportError from fill_missing() (presumably
        raised by crypto_validator without the cryptography library) is
        ignored, leaving dh_Yc empty.
        """
        if not self.dh_Yc:
            try:
                self.fill_missing()
            except ImportError:
                pass
        if self.dh_Yclen is None:
            self.dh_Yclen = len(self.dh_Yc)
        return b"".join((pkcs_i2osp(self.dh_Yclen, 2), self.dh_Yc, pay))

    def post_dissection(self, m):
        """
        First we update the client DHParams. Then, we try to update the server
        DHParams generated during Server*DHParams building, with the shared
        secret. Finally, we derive the session keys and update the context.
        """
        session = self.tls_session

        # if there are kx params and keys, we assume the crypto library is ok
        if session.client_kx_ffdh_params:
            y = pkcs_os2ip(self.dh_Yc)
            pn = session.client_kx_ffdh_params.parameter_numbers()
            session.client_kx_pubkey = \
                dh.DHPublicNumbers(y, pn).public_key(default_backend())

        if not (session.server_kx_privkey and session.client_kx_pubkey):
            return
        shared = session.server_kx_privkey.exchange(session.client_kx_pubkey)
        session.pre_master_secret = shared.lstrip(b"\x00")
        if not session.extms:
            session.compute_ms_and_derive_keys()

    def guess_payload_class(self, p):
        return Padding
809
810
class ClientECDiffieHellmanPublic(_GenericTLSSessionInheritance):
    """
    Client's (EC)DH public value for ECDHE key exchanges (ClientKeyExchange).

    Note that the length field is a single byte here, whereas
    ClientDiffieHellmanPublic uses a two-byte length field.
    """
    name = "Client ECDH Public Value"
    fields_desc = [FieldLenField("ecdh_Yclen", None,
                                 length_of="ecdh_Yc", fmt="B"),
                   StrLenField("ecdh_Yc", "",
                               length_from=lambda pkt: pkt.ecdh_Yclen)]

    @staticmethod
    def _is_x25519_or_x448(key):
        """
        Return True if *key* (public or private) is an X25519/X448 key.

        The x25519/x448 modules are only imported when
        conf.crypto_valid_advanced is set, so the check short-circuits to
        False otherwise instead of raising NameError.
        """
        if not conf.crypto_valid_advanced:
            return False
        return isinstance(key, (x25519.X25519PublicKey,
                                x25519.X25519PrivateKey,
                                x448.X448PublicKey,
                                x448.X448PrivateKey))

    @classmethod
    def _ecdh_exchange(cls, privkey, peer_pubkey):
        """
        Return the raw shared secret between *privkey* and *peer_pubkey*.

        X25519/X448 private keys take only the peer key, while classic
        elliptic-curve private keys also take the ec.ECDH() algorithm.
        """
        if cls._is_x25519_or_x448(privkey):
            return privkey.exchange(peer_pubkey)
        return privkey.exchange(ec.ECDH(), peer_pubkey)

    @crypto_validator
    def fill_missing(self):
        """
        Generate the client key pair for the session's named group, store
        the encoded public value in ecdh_Yc and, when the server public key
        is already known, compute the premaster secret.
        """
        s = self.tls_session
        s.client_kx_privkey = _tls_named_groups_generate(
            s.client_kx_ecdh_params
        )
        # ecdh_Yc follows ECPoint.point format as defined in
        # https://tools.ietf.org/html/rfc8422#section-5.4
        pubkey = s.client_kx_privkey.public_key()
        if self._is_x25519_or_x448(pubkey):
            self.ecdh_Yc = pubkey.public_bytes(
                serialization.Encoding.Raw,
                serialization.PublicFormat.Raw
            )
        else:
            # uncompressed format of an elliptic curve point:
            # 0x04 || X || Y, each coordinate on ceil(key_size / 8) bytes.
            # A floor division would give 65 instead of 66 bytes for a
            # 521-bit curve.
            coord_len = (pubkey.key_size + 7) // 8
            numbers = pubkey.public_numbers()
            self.ecdh_Yc = (b"\x04" +
                            pkcs_i2osp(numbers.x, coord_len) +
                            pkcs_i2osp(numbers.y, coord_len))

        if s.client_kx_privkey and s.server_kx_pubkey:
            # no leading-zero stripping for (EC)DH, unlike finite-field DH
            s.pre_master_secret = self._ecdh_exchange(s.client_kx_privkey,
                                                      s.server_kx_pubkey)
            if not s.extms:
                s.compute_ms_and_derive_keys()

    def post_build(self, pkt, pay):
        """
        Fill ecdh_Yc if the user left it empty, then rebuild the layer from
        the field values (the serialized *pkt* may predate fill_missing).
        """
        if not self.ecdh_Yc:
            try:
                self.fill_missing()
            except ImportError:
                # crypto library unavailable: build with what we have
                pass
        if self.ecdh_Yclen is None:
            self.ecdh_Yclen = len(self.ecdh_Yc)
        return pkcs_i2osp(self.ecdh_Yclen, 1) + self.ecdh_Yc + pay

    def post_dissection(self, m):
        """
        Import the client public value, then (server side) compute the
        shared secret with the server private key and derive the keys.
        """
        s = self.tls_session

        # if there are kx params and keys, we assume the crypto library is ok
        if s.client_kx_ecdh_params:
            s.client_kx_pubkey = _tls_named_groups_import(
                s.client_kx_ecdh_params,
                self.ecdh_Yc
            )

        if s.server_kx_privkey and s.client_kx_pubkey:
            s.pre_master_secret = self._ecdh_exchange(s.server_kx_privkey,
                                                      s.client_kx_pubkey)
            if not s.extms:
                s.compute_ms_and_derive_keys()
879
880
881# RSA Encryption (standard & export)
882
class _UnEncryptedPreMasterSecret(Raw):
    """
    Opaque placeholder for an RSA-encrypted premaster secret.

    Selected by EncryptedPreMasterSecret.dispatch_hook when the session has
    no server RSA key, so the ciphertext cannot be deciphered and is kept
    as raw bytes instead.
    """
    name = "RSA Encrypted PreMaster Secret (protected)"

    def __init__(self, *args, **kargs):
        # The TLS dissection passes a tls_session keyword; Raw is not a
        # session-aware layer, so strip it before delegating.
        kargs.pop('tls_session', None)
        super(_UnEncryptedPreMasterSecret, self).__init__(*args, **kargs)
893
894
class EncryptedPreMasterSecret(_GenericTLSSessionInheritance):
    """
    RSA key exchange premaster secret, as carried in a ClientKeyExchange.

    The two fields describe the *plaintext* premaster secret (client version
    followed by 46 bytes). The RSA encryption is applied in post_build and
    removed in pre_dissect, using the keys stored in the TLS session.
    Pay attention to implementation notes in section 7.4.7.1 of RFC 5246.
    """
    name = "RSA Encrypted PreMaster Secret"
    fields_desc = [_TLSClientVersionField("client_version", None,
                                          _tls_version),
                   StrFixedLenField("random", None, 46)]

    @classmethod
    def dispatch_hook(cls, _pkt=None, *args, **kargs):
        """
        Pick the opaque _UnEncryptedPreMasterSecret when dissecting with a
        session that has neither a temporary nor a regular server RSA key.
        """
        if not _pkt or 'tls_session' not in kargs:
            return EncryptedPreMasterSecret
        sess = kargs['tls_session']
        lacks_rsa_key = (sess.server_tmp_rsa_key is None and
                         sess.server_rsa_key is None)
        if lacks_rsa_key:
            return _UnEncryptedPreMasterSecret
        return EncryptedPreMasterSecret

    def pre_dissect(self, m):
        """
        Decrypt the incoming ciphertext and return the 48-byte plaintext
        premaster secret for field dissection. The secret is also stored in
        the session, and keys are derived unless extms is set.
        """
        sess = self.tls_session
        version = sess.tls_version
        if version is None:
            version = sess.advertised_tls_version

        ciphertext = m
        if version >= 0x0301:
            # TLS 1.0+ expects a 2-byte length before the ciphertext
            if len(m) < 2:      # Should not happen
                return m
            declared_len = struct.unpack("!H", m[:2])[0]
            if len(m) == declared_len + 2:
                ciphertext = m[2:]
            else:
                err = "TLS 1.0+, but RSA Encrypted PMS with no explicit length"
                warning(err)

        # priority is given to the tmp_key, if there is one
        rsa_key = sess.server_tmp_rsa_key
        if rsa_key is None:
            rsa_key = sess.server_rsa_key

        if rsa_key is None:
            # the dispatch_hook is supposed to prevent this case
            pms = b"\x00" * 48
            err = "No server RSA key to decrypt Pre Master Secret. Skipping."
            warning(err)
        else:
            # keep the trailing 48 bytes (2-byte version + 46-byte random)
            pms = rsa_key.decrypt(ciphertext)[-48:]

        sess.pre_master_secret = pms
        if not sess.extms:
            sess.compute_ms_and_derive_keys()

        return pms

    def post_build(self, pkt, pay):
        """
        We encrypt the premaster secret (the 48 bytes) with either the server
        certificate or the temporary RSA key provided in a server key exchange
        message. After that step, we add the 2 bytes to provide the length, as
        described in implementation notes at the end of section 7.4.7.1.
        """
        sess = self.tls_session
        # *pkt* is the plaintext premaster secret built from the fields
        sess.pre_master_secret = pkt
        if not sess.extms:
            sess.compute_ms_and_derive_keys()

        if sess.server_tmp_rsa_key is not None:
            enc = sess.server_tmp_rsa_key.encrypt(pkt, t="pkcs")
        elif sess.server_certs is not None and len(sess.server_certs) > 0:
            enc = sess.server_certs[0].encrypt(pkt, t="pkcs")
        else:
            # no key available: the plaintext is emitted unchanged
            warning("No material to encrypt Pre Master Secret")
            enc = pkt

        version = sess.tls_version
        if version is None:
            version = sess.advertised_tls_version
        if version >= 0x0301:
            length_prefix = struct.pack("!H", len(enc))
        else:
            length_prefix = b""
        return length_prefix + enc + pay

    def guess_payload_class(self, p):
        return Padding
977
978
979# Pre-Shared Key
980
class ClientPSKIdentity(Packet):
    """
    PSK identity sent by the client in a ClientKeyExchange (RFC 4279).

    XXX We provide parsing abilities for the client PSK identity, but the
    context operations have not been implemented yet. See RFC 4279.
    Note that we do not cover the (EC)DHE_PSK nor the RSA_PSK key exchange,
    which should contain either an EncryptedPMS or a ClientDiffieHellmanPublic.
    """
    # display name; previously copy-pasted from the server-side PSK layer
    name = "Client PSK Identity"
    fields_desc = [FieldLenField("psk_identity_len", None,
                                 length_of="psk_identity", fmt="!H"),
                   StrLenField("psk_identity", "",
                               length_from=lambda pkt: pkt.psk_identity_len)]
993