• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# SPDX-License-Identifier: GPL-2.0-only
2# This file is part of Scapy
3# See https://scapy.net/ for more information
4# Copyright (C) 2007, 2008, 2009 Arnaud Ebalard
5#               2015, 2016, 2017 Maxence Tury
6
7"""
8The _TLSAutomaton class provides methods common to both TLS client and server.
9"""
10
11import select
12import socket
13import struct
14
15from scapy.automaton import Automaton
16from scapy.config import conf
17from scapy.error import log_interactive
18from scapy.packet import Raw
19from scapy.layers.tls.basefields import _tls_type
20from scapy.layers.tls.cert import Cert, PrivKey
21from scapy.layers.tls.record import TLS
22from scapy.layers.tls.record_sslv2 import SSLv2
23from scapy.layers.tls.record_tls13 import TLS13
24
25
class _TLSAutomaton(Automaton):
    """
    SSLv3 and TLS 1.0-1.2 typically need a 2-RTT handshake:

    Client        Server
      | --------->>> |    C1 - ClientHello
      | <<<--------- |    S1 - ServerHello
      | <<<--------- |    S1 - Certificate
      | <<<--------- |    S1 - ServerKeyExchange
      | <<<--------- |    S1 - ServerHelloDone
      | --------->>> |    C2 - ClientKeyExchange
      | --------->>> |    C2 - ChangeCipherSpec
      | --------->>> |    C2 - Finished [encrypted]
      | <<<--------- |    S2 - ChangeCipherSpec
      | <<<--------- |    S2 - Finished [encrypted]

    We call these successive groups of messages:
    ClientFlight1, ServerFlight1, ClientFlight2 and ServerFlight2.

    With TLS 1.3, the handshake requires only 1-RTT:

    Client        Server
      | --------->>> |    C1 - ClientHello
      | <<<--------- |    S1 - ServerHello
      | <<<--------- |    S1 - Certificate [encrypted]
      | <<<--------- |    S1 - CertificateVerify [encrypted]
      | <<<--------- |    S1 - Finished [encrypted]
      | --------->>> |    C2 - Finished [encrypted]

    We want to send our messages from the same flight all at once through the
    socket. This is achieved by managing a list of records in 'buffer_out'.
    We may put several messages (i.e. what RFC 5246 calls the record fragments)
    in the same record when possible, but we may need several records for the
    same flight, as with ClientFlight2.

    However, note that the flights from the opposite side may be spread wildly
    across TLS records and TCP packets. This is why we use a 'get_next_msg'
    method for feeding a list of received messages, 'buffer_in'. Raw data
    which has not yet been interpreted as a TLS record is kept in 'remain_in'.
    """

    def __init__(self, *args, **kwargs):
        """
        Build the automaton with no-op 'll' and 'recvsock' factories.

        Both are replaced by callables returning None, presumably so that the
        base Automaton does not open Scapy sockets of its own: this class does
        its network I/O through self.socket instead (TODO confirm against the
        Automaton base class).
        """
        kwargs["ll"] = lambda *args, **kwargs: None
        kwargs["recvsock"] = lambda *args, **kwargs: None
        super(_TLSAutomaton, self).__init__(*args, **kwargs)

    def parse_args(self, mycert=None, mykey=None, **kargs):
        """
        Initialize the I/O buffers, the session slots and local credentials.

        :param mycert: if truthy, passed to Cert() and stored as self.mycert;
            otherwise self.mycert is None.
        :param mykey: if truthy, passed to PrivKey() and stored as self.mykey;
            otherwise self.mykey is None.
        :param kargs: may contain 'verbose' (default True), which controls
            vprint(); it is removed before the remaining keyword arguments are
            forwarded to Automaton.parse_args().
        """
        self.verbose = kargs.pop("verbose", True)

        super(_TLSAutomaton, self).parse_args(**kargs)

        self.socket = None
        self.remain_in = b""
        self.buffer_in = []         # these are 'fragments' inside records
        self.buffer_out = []        # these are records

        self.cur_session = None
        self.cur_pkt = None         # this is usually the latest parsed packet

        if mycert:
            self.mycert = Cert(mycert)
        else:
            self.mycert = None

        if mykey:
            self.mykey = PrivKey(mykey)
        else:
            self.mykey = None

    def get_next_msg(self, socket_timeout=2, retry=2):
        """
        Make the next message(s) available in self.buffer_in.

        If self.buffer_in is not empty, nothing is done. Otherwise the bytes
        already buffered in self.remain_in (left over from a previous call)
        are completed with data read from self.socket until at least one full
        record is available. The buffered data is then dissected, and the
        messages it carries (record 'fragments') are appended to
        self.buffer_in. Trailing bytes that dissect as Raw are put back into
        self.remain_in for the next call.

        The record length is discovered incrementally. The first 2 bytes are
        inspected first: if they are \\x14\\x03, \\x15\\x03, \\x16\\x03 or
        \\x17\\x03 (a content type from _tls_type followed by 3), this is a
        TLS record and 5 bytes are needed to read its length. Anything else
        is assumed to be an SSLv2 record, whose first bytes directly encode
        its length. Whatever header bytes are already buffered are decoded
        before reading from the socket again.

        :param socket_timeout: seconds each select() call waits for the socket
            to become readable.
        :param retry: number of select() timeouts or socket errors tolerated
            before giving up. Giving up returns without adding anything to
            self.buffer_in; partially received data stays in self.remain_in.
        :raises SOCKET_CLOSED: the automaton state (presumably defined by the
            client/server subclasses) raised when recv() returns no data.
        """
        if self.buffer_in:
            # A message is already available.
            return

        is_sslv2_msg = False
        still_getting_len = True
        grablen = 2
        while retry and (still_getting_len or len(self.remain_in) < grablen):
            # Decode as much of the record header as is already buffered.
            # Both stages run in the same iteration so that leftover bytes
            # from a previous call are fully used; otherwise we could break
            # before reading the length, or ask recv() for a negative size.
            if grablen == 2 and len(self.remain_in) >= 2:
                byte0, byte1 = struct.unpack("BB", self.remain_in[:2])
                if (byte0 in _tls_type) and (byte1 == 3):
                    # Retry following TLS scheme. This will cause failure
                    # for SSLv2 packets with length 0x1{4-7}03.
                    grablen = 5
                else:
                    # Extract the SSLv2 length.
                    is_sslv2_msg = True
                    still_getting_len = False
                    if byte0 & 0x80:
                        grablen = 2 + 0 + ((byte0 & 0x7f) << 8) + byte1
                    else:
                        grablen = 2 + 1 + ((byte0 & 0x3f) << 8) + byte1
            if not is_sslv2_msg and grablen == 5 and len(self.remain_in) >= 5:
                # TLS record header: the length is the big-endian 16-bit
                # field at offset 3, not counting the 5 header bytes.
                grablen = struct.unpack('!H', self.remain_in[3:5])[0] + 5
                still_getting_len = False

            if not still_getting_len and len(self.remain_in) >= grablen:
                # At least one full record is buffered.
                break

            final = False
            try:
                tmp, _, _ = select.select([self.socket], [], [],
                                          socket_timeout)
                if not tmp:
                    retry -= 1
                else:
                    data = tmp[0].recv(grablen - len(self.remain_in))
                    if not data:
                        # Socket peer was closed
                        self.vprint("Peer socket closed !")
                        final = True
                    else:
                        self.remain_in += data
            except Exception as ex:
                # Best effort: any socket error just consumes one retry.
                if not isinstance(ex, socket.timeout):
                    self.vprint("Could not join host (%s) ! Retrying..." % ex)
                retry -= 1
            else:
                if final:
                    raise self.SOCKET_CLOSED()

        if still_getting_len or len(self.remain_in) < grablen:
            # Remote peer is not willing to respond
            return

        # From here on the header was decoded, so byte0 is always bound.
        session = self.cur_session
        if (byte0 == 0x17 and session is not None and
                ((session.advertised_tls_version or 0) >= 0x0304 or
                 (session.tls_version or 0) >= 0x0304)):
            p = TLS13(self.remain_in, tls_session=session)
            self.remain_in = b""
            self.buffer_in += p.inner.msg
        else:
            p = TLS(self.remain_in, tls_session=self.cur_session)
            self.cur_session = p.tls_session
            self.remain_in = b""
            if isinstance(p, SSLv2) and not p.msg:
                p.msg = Raw("")
            if self.cur_session.tls_version is None or \
               self.cur_session.tls_version < 0x0304:
                self.buffer_in += p.msg
            else:
                if isinstance(p, TLS13):
                    self.buffer_in += p.inner.msg
                else:
                    # should be TLS13ServerHello only
                    self.buffer_in += p.msg

        # Further records, and trailing bytes that could not be dissected as
        # a record, are chained as payloads of the first one. Records are
        # handled as above: TLS13 records carry their messages in 'inner'.
        while p.payload:
            p = p.payload
            if isinstance(p, Raw):
                # Not (yet) a full record: keep it for the next call.
                self.remain_in += p.load
            elif isinstance(p, TLS13):
                self.buffer_in += p.inner.msg
            elif isinstance(p, TLS):
                self.buffer_in += p.msg

    def raise_on_packet(self, pkt_cls, state, get_next_msg=True):
        """
        If the next message to be processed has type 'pkt_cls', raise 'state'.

        If there is no message waiting to be processed, we try to get one with
        the default 'get_next_msg' parameters (unless get_next_msg is False).
        On a match, the message is removed from self.buffer_in and stored in
        self.cur_pkt before the state is raised; otherwise nothing changes.
        """
        # Maybe we already parsed the expected packet, maybe not.
        if get_next_msg:
            self.get_next_msg()
        if (not self.buffer_in or
                not isinstance(self.buffer_in[0], pkt_cls)):
            return
        self.cur_pkt = self.buffer_in[0]
        self.buffer_in = self.buffer_in[1:]
        raise state()

    def in_handshake(self, pkt_cls):
        """
        Return True if the pkt_cls was present during the handshake.

        This is used to detect whether Certificates were requested, etc. It
        looks at self.cur_session.handshake_messages_parsed.
        """
        return any(
            isinstance(m, pkt_cls)
            for m in self.cur_session.handshake_messages_parsed
        )

    def add_record(self, is_sslv2=None, is_tls13=None, is_tls12=None):
        """
        Append a new, empty SSLv2, TLS 1.3, TLS 1.2 or TLS record to
        self.buffer_out.

        When none of the flags is given, the record type is derived from the
        session's negotiated version, falling back to the advertised one:
        0x0200/0x0002 selects SSLv2, 0x0304 and above selects TLS 1.3, and
        anything else (including an unknown version) a plain TLS record.
        """
        if is_sslv2 is None and is_tls13 is None and is_tls12 is None:
            v = (self.cur_session.tls_version or
                 self.cur_session.advertised_tls_version)
            if v in [0x0200, 0x0002]:
                is_sslv2 = True
            elif v is not None and v >= 0x0304:
                is_tls13 = True
        if is_sslv2:
            self.buffer_out.append(SSLv2(tls_session=self.cur_session))
        elif is_tls13:
            self.buffer_out.append(TLS13(tls_session=self.cur_session))
        # For TLS 1.3 middlebox compatibility, TLS record version must
        # be 0x0303
        elif is_tls12:
            self.buffer_out.append(TLS(version="TLS 1.2",
                                       tls_session=self.cur_session))
        else:
            self.buffer_out.append(TLS(tls_session=self.cur_session))

    def add_msg(self, pkt):
        """
        Add a TLS message (e.g. TLSClientHello or TLSApplicationData)
        inside the latest record to be sent through the socket.

        A record is created with add_record() defaults if none is buffered;
        we believe a good automaton should not rely on that fallback.
        """
        if not self.buffer_out:
            self.add_record()
        r = self.buffer_out[-1]
        if isinstance(r, TLS13):
            # TLS 1.3 records carry their messages in the inner plaintext.
            r.inner.msg.append(pkt)
        else:
            r.msg.append(pkt)

    def flush_records(self):
        """
        Send all buffered records and update the session accordingly.

        Every record is serialized with raw_stateful(), the results are sent
        through self.socket in one go, then self.buffer_out is emptied.
        sendall() is used because send() may transmit only part of the data.
        """
        s = b"".join(p.raw_stateful() for p in self.buffer_out)
        self.socket.sendall(s)
        self.buffer_out = []

    def vprint(self, s=""):
        """
        Print '> ' followed by s when self.verbose is set.

        In interactive sessions (conf.interactive), the message goes through
        log_interactive at INFO level instead of print().
        """
        if self.verbose:
            if conf.interactive:
                log_interactive.info("> %s", s)
            else:
                print("> %s" % s)
282