• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# SPDX-License-Identifier: GPL-2.0-only
2# This file is part of Scapy
3# See https://scapy.net/ for more information
4# Copyright (C) Philippe Biondi <phil@secdev.org>
5
6"""
7SuperSocket.
8"""
9
10from select import select, error as select_error
11import ctypes
12import errno
13import socket
14import struct
15import time
16
17from scapy.config import conf
18from scapy.consts import DARWIN, WINDOWS
19from scapy.data import (
20    MTU,
21    ETH_P_IP,
22    ETH_P_IPV6,
23    SOL_PACKET,
24    SO_TIMESTAMPNS,
25)
26from scapy.compat import raw
27from scapy.error import warning, log_runtime
28from scapy.interfaces import network_name
29from scapy.packet import Packet, NoPayload
30from scapy.plist import (
31    PacketList,
32    SndRcvList,
33    _PacketIterable,
34)
35from scapy.utils import PcapReader, tcpdump
36
37# Typing imports
38from scapy.interfaces import _GlobInterfaceType
39from typing import (
40    Any,
41    Dict,
42    Iterator,
43    List,
44    Optional,
45    Tuple,
46    Type,
47    cast,
48)
49
50# Utils
51
52
class _SuperSocket_metaclass(type):
    """Metaclass giving SuperSocket classes a readable ``repr()``.

    The class-level ``desc`` string, when set, is appended to the class
    name; classes that do not define it fall back to this ``None`` default.
    """
    desc = None   # type: Optional[str]

    def __repr__(self):
        # type: () -> str
        cls_name = self.__name__
        if self.desc is None:
            return "<%s>" % cls_name
        return "<%s: %s>" % (cls_name, self.desc)
62
63
# Used to get ancillary data (parsed in SuperSocket._recv_raw)
# Option / control-message type at level SOL_PACKET that carries a
# tpacket_auxdata record
PACKET_AUXDATA = 8
# TPID used for a re-inserted VLAN tag when the kernel does not flag
# tp_vlan_tpid as valid
ETH_P_8021Q = 0x8100
# tpacket_auxdata.tp_status bit: tp_vlan_tci holds a valid VLAN tag
TP_STATUS_VLAN_VALID = 1 << 4
# tpacket_auxdata.tp_status bit: tp_vlan_tpid holds a valid TPID
TP_STATUS_VLAN_TPID_VALID = 1 << 6
69
70
class tpacket_auxdata(ctypes.Structure):
    """ctypes mirror of the Linux ``tpacket_auxdata`` ancillary record.

    It is filled with ``from_buffer_copy`` from the ``PACKET_AUXDATA``
    control message in ``SuperSocket._recv_raw``, mainly to recover the
    VLAN tag stripped by the kernel. Because it is a raw binary mapping,
    field order and widths must not change.
    """
    _fields_ = [
        ("tp_status", ctypes.c_uint),       # flags, e.g. TP_STATUS_VLAN_VALID
        ("tp_len", ctypes.c_uint),
        ("tp_snaplen", ctypes.c_uint),
        ("tp_mac", ctypes.c_ushort),
        ("tp_net", ctypes.c_ushort),
        ("tp_vlan_tci", ctypes.c_ushort),   # VLAN TCI, re-inserted into the frame
        ("tp_vlan_tpid", ctypes.c_ushort),  # used when TP_STATUS_VLAN_TPID_VALID
    ]  # type: List[Tuple[str, Any]]
81
82
83# SuperSocket
84
class SuperSocket(metaclass=_SuperSocket_metaclass):
    """Base class of Scapy sockets.

    ``ins`` is the object packets are read from and ``outs`` the object
    packets are written to; by default both are the same
    ``socket.socket``. Subclasses override ``recv_raw``/``recv``/``send``
    to change how packets are read, dissected and written.
    """
    closed = False  # type: bool
    nonblocking_socket = False  # type: bool
    auxdata_available = False   # type: bool

    def __init__(self,
                 family=socket.AF_INET,  # type: int
                 type=socket.SOCK_STREAM,  # type: int
                 proto=0,  # type: int
                 iface=None,  # type: Optional[_GlobInterfaceType]
                 **kwargs  # type: Any
                 ):
        # type: (...) -> None
        """Open a ``socket.socket(family, type, proto)``.

        :param family: socket address family
        :param type: socket type
        :param proto: protocol number
        :param iface: interface name; ``conf.iface`` when falsy
        :param kwargs: accepted and ignored
        """
        self.ins = socket.socket(family, type, proto)  # type: socket.socket
        self.outs = self.ins  # type: Optional[socket.socket]
        self.promisc = conf.sniff_promisc
        self.iface = iface or conf.iface

    def send(self, x):
        # type: (Packet) -> int
        """Sends a `Packet` object

        :param x: `Packet` to be sent
        :return: Number of bytes that have been sent (0 when there is no
                 output socket)
        """
        sx = raw(x)
        try:
            x.sent_time = time.time()
        except AttributeError:
            # x may not accept attribute assignment
            pass

        if self.outs:
            return self.outs.send(sx)
        else:
            return 0

    if WINDOWS:
        def _recv_raw(self, sock, x):
            # type: (socket.socket, int) -> Tuple[bytes, Any, Optional[float]]
            """Internal function to receive a Packet.

            :param sock: Socket object from which data are received
            :param x: Number of bytes to be received
            :return: Received bytes, address information and no timestamp
            """
            pkt, sa_ll = sock.recvfrom(x)
            return pkt, sa_ll, None
    else:
        def _recv_raw(self, sock, x):
            # type: (socket.socket, int) -> Tuple[bytes, Any, Optional[float]]
            """Internal function to receive a Packet,
            and process ancillary data.

            When ``auxdata_available`` is set, the PACKET_AUXDATA control
            message is used to re-insert the VLAN tag after the two MAC
            addresses (offset 12), and SO_TIMESTAMPNS provides the
            timestamp.

            :param sock: Socket object from which data are received
            :param x: Number of bytes to be received
            :return: Received bytes, address information and an optional timestamp
            """
            timestamp = None
            if not self.auxdata_available:
                pkt, _, _, sa_ll = sock.recvmsg(x)
                return pkt, sa_ll, timestamp
            ancbufsize = socket.CMSG_LEN(4096)
            pkt, ancdata, flags, sa_ll = sock.recvmsg(x, ancbufsize)
            if not pkt:
                return pkt, sa_ll, timestamp
            for cmsg_lvl, cmsg_type, cmsg_data in ancdata:
                # Check available ancillary data
                if (cmsg_lvl == SOL_PACKET and cmsg_type == PACKET_AUXDATA):
                    # Parse AUXDATA
                    try:
                        auxdata = tpacket_auxdata.from_buffer_copy(cmsg_data)
                    except ValueError:
                        # Note: according to Python documentation, recvmsg()
                        #       can return a truncated message. A ValueError
                        #       exception likely indicates that Auxiliary
                        #       Data is not supported by the Linux kernel.
                        return pkt, sa_ll, timestamp
                    if auxdata.tp_vlan_tci != 0 or \
                            auxdata.tp_status & TP_STATUS_VLAN_VALID:
                        # Insert VLAN tag
                        tpid = ETH_P_8021Q
                        if auxdata.tp_status & TP_STATUS_VLAN_TPID_VALID:
                            tpid = auxdata.tp_vlan_tpid
                        tag = struct.pack(
                            "!HH",
                            tpid,
                            auxdata.tp_vlan_tci
                        )
                        pkt = pkt[:12] + tag + pkt[12:]
                elif cmsg_lvl == socket.SOL_SOCKET and \
                        cmsg_type == SO_TIMESTAMPNS:
                    length = len(cmsg_data)
                    if length == 16:  # __kernel_timespec: two 64-bit fields
                        # "qq" rather than native "ll": a C long is only 4
                        # bytes on 32-bit platforms, where "ll" would not
                        # match this 16-byte payload and unpack would raise.
                        tmp = struct.unpack("qq", cmsg_data)
                    elif length == 8:  # timespec with 32-bit fields
                        tmp = struct.unpack("ii", cmsg_data)
                    else:
                        log_runtime.warning("Unknown timespec format.. ?!")
                        continue
                    timestamp = tmp[0] + tmp[1] * 1e-9
            return pkt, sa_ll, timestamp

    def recv_raw(self, x=MTU):
        # type: (int) -> Tuple[Optional[Type[Packet]], Optional[bytes], Optional[float]]  # noqa: E501
        """Returns a tuple containing (cls, pkt_data, time)

        :param x: Maximum number of bytes to be received, defaults to MTU
        :return: A tuple, consisting of a Packet type, the received data,
                 and a timestamp
        """
        return conf.raw_layer, self.ins.recv(x), None

    def recv(self, x=MTU, **kwargs):
        # type: (int, **Any) -> Optional[Packet]
        """Receive a Packet according to the `basecls` of this socket

        If dissection fails (and ``conf.debug_dissector`` is off), the data
        is returned wrapped in ``conf.raw_layer``.

        :param x: Maximum number of bytes to be received, defaults to MTU
        :return: The received `Packet` object, or None
        """
        cls, val, ts = self.recv_raw(x)
        if not val or not cls:
            return None
        try:
            pkt = cls(val, **kwargs)  # type: Packet
        except KeyboardInterrupt:
            raise
        except Exception:
            if conf.debug_dissector:
                from scapy.sendrecv import debug
                debug.crashed_on = (cls, val)
                raise
            pkt = conf.raw_layer(val)
        if ts:
            pkt.time = ts
        return pkt

    def fileno(self):
        # type: () -> int
        """Return the file descriptor of the input socket."""
        return self.ins.fileno()

    def close(self):
        # type: () -> None
        """Gracefully close this socket

        Idempotent: later calls return immediately. ``outs`` is closed
        separately only when it is a different object from ``ins``.
        """
        if self.closed:
            return
        self.closed = True
        if getattr(self, "outs", None):
            if getattr(self, "ins", None) != self.outs:
                if self.outs and self.outs.fileno() != -1:
                    self.outs.close()
        if getattr(self, "ins", None):
            if self.ins.fileno() != -1:
                self.ins.close()

    def sr(self, *args, **kargs):
        # type: (Any, Any) -> Tuple[SndRcvList, PacketList]
        """Send and Receive multiple packets
        """
        from scapy import sendrecv
        return sendrecv.sndrcv(self, *args, **kargs)

    def sr1(self, *args, **kargs):
        # type: (Any, Any) -> Optional[Packet]
        """Send one packet and receive one answer
        """
        from scapy import sendrecv
        ans = sendrecv.sndrcv(self, *args, **kargs)[0]  # type: SndRcvList
        if len(ans) > 0:
            pkt = ans[0][1]  # type: Packet
            return pkt
        else:
            return None

    def sniff(self, *args, **kargs):
        # type: (Any, Any) -> PacketList
        """Sniff packets from this socket.

        Arguments are forwarded to ``scapy.sendrecv.sniff`` with
        ``opened_socket=self``.
        """
        from scapy import sendrecv
        return sendrecv.sniff(opened_socket=self, *args, **kargs)

    def tshark(self, *args, **kargs):
        # type: (Any, Any) -> None
        """Run ``scapy.sendrecv.tshark`` on this socket.

        Arguments are forwarded with ``opened_socket=self``.
        """
        from scapy import sendrecv
        sendrecv.tshark(opened_socket=self, *args, **kargs)

    # TODO: use 'scapy.ansmachine.AnsweringMachine' when typed
    def am(self,
           cls,  # type: Type[Any]
           *args,  # type: Any
           **kwargs  # type: Any
           ):
        # type: (...) -> Any
        """
        Creates an AnsweringMachine associated with this socket.

        :param cls: A subclass of AnsweringMachine to instantiate
        """
        return cls(*args, opened_socket=self, socket=self, **kwargs)

    @staticmethod
    def select(sockets, remain=conf.recv_poll_rate):
        # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
        """This function is called during sendrecv() routine to select
        the available sockets.

        :param sockets: an array of sockets that need to be selected
        :param remain: timeout in seconds (None blocks)
        :returns: the sockets that are ready to be read; an empty list on
            timeout or when the wait was interrupted (EINTR)
        """
        # Pre-initialised so that an interrupted select (EINTR, swallowed
        # below) reports "nothing ready" instead of raising
        # UnboundLocalError at the return statement.
        inp = []  # type: List[SuperSocket]
        try:
            inp, _, _ = select(sockets, [], [], remain)
        except (IOError, select_error) as exc:
            # select.error has no .errno attribute
            if not exc.args or exc.args[0] != errno.EINTR:
                raise
        return inp

    def __del__(self):
        # type: () -> None
        """Close the socket"""
        self.close()

    def __enter__(self):
        # type: () -> SuperSocket
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # type: (Optional[Type[BaseException]], Optional[BaseException], Optional[Any]) -> None  # noqa: E501
        """Close the socket"""
        self.close()
315
316
if not WINDOWS:
    class L3RawSocket(SuperSocket):
        """Layer 3 socket built on Linux raw sockets.

        Packets are sent through an ``AF_INET``/``SOCK_RAW`` socket with
        ``IP_HDRINCL`` enabled, so the IP header comes from the packet
        itself. Packets are received through an ``AF_PACKET``/``SOCK_RAW``
        socket.
        """
        desc = "Layer 3 using Raw sockets (PF_INET/SOCK_RAW)"

        def __init__(self,
                     type=ETH_P_IP,  # type: int
                     filter=None,  # type: Optional[str]
                     iface=None,  # type: Optional[_GlobInterfaceType]
                     promisc=None,  # type: Optional[bool]
                     nofilter=0  # type: int
                     ):
            # type: (...) -> None
            """
            :param type: ethertype the receive socket listens for
            :param filter: accepted for API compatibility, not applied here
            :param iface: interface the receive socket is bound to; when
                          None it is left unbound and ``iface`` is "any"
            :param promisc: accepted for API compatibility, not applied here
            :param nofilter: accepted for API compatibility, not applied here
            """
            self.outs = socket.socket(socket.AF_INET, socket.SOCK_RAW, socket.IPPROTO_RAW)  # noqa: E501
            self.outs.setsockopt(socket.SOL_IP, socket.IP_HDRINCL, 1)
            self.ins = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, socket.htons(type))  # noqa: E501
            if iface is not None:
                iface = network_name(iface)
                self.iface = iface
                self.ins.bind((iface, type))
            else:
                self.iface = "any"
            try:
                # Receive Auxiliary Data (VLAN tags) and nanosecond
                # timestamps as control messages, parsed by _recv_raw()
                self.ins.setsockopt(SOL_PACKET, PACKET_AUXDATA, 1)
                self.ins.setsockopt(
                    socket.SOL_SOCKET,
                    SO_TIMESTAMPNS,
                    1
                )
                self.auxdata_available = True
            except OSError:
                # Note: Auxiliary Data is only supported since
                #       Linux 2.6.21
                msg = "Your Linux Kernel does not support Auxiliary Data!"
                log_runtime.info(msg)

        def recv(self, x=MTU, **kwargs):
            # type: (int, **Any) -> Optional[Packet]
            """Receive one packet and return it starting at layer 3.

            Packets sent by this host (PACKET_OUTGOING) are skipped and
            ``None`` is returned for them.
            """
            data, sa_ll, ts = self._recv_raw(self.ins, x)
            # sa_ll is the AF_PACKET address:
            # (ifname, proto, pkttype, hatype, addr)
            if sa_ll[2] == socket.PACKET_OUTGOING:
                return None
            if sa_ll[3] in conf.l2types:
                cls = conf.l2types.num2layer[sa_ll[3]]  # type: Type[Packet]
                lvl = 2
            elif sa_ll[1] in conf.l3types:
                cls = conf.l3types.num2layer[sa_ll[1]]
                lvl = 3
            else:
                cls = conf.default_l2
                warning("Unable to guess type (interface=%s protocol=%#x family=%i). Using %s", sa_ll[0], sa_ll[1], sa_ll[3], cls.name)  # noqa: E501
                lvl = 3

            try:
                pkt = cls(data, **kwargs)
            except KeyboardInterrupt:
                raise
            except Exception:
                if conf.debug_dissector:
                    raise
                pkt = conf.raw_layer(data)

            if lvl == 2:
                # Drop the link layer so a layer 3 packet is returned
                pkt = pkt.payload

            if pkt is not None:
                if ts is None:
                    # No SO_TIMESTAMPNS control message: ask the kernel
                    from scapy.arch.linux import get_last_packet_timestamp
                    ts = get_last_packet_timestamp(self.ins)
                pkt.time = ts
            return pkt

        def send(self, x):
            # type: (Packet) -> int
            """Send ``x`` to ``x.dst`` through the raw output socket.

            :return: number of bytes sent, 0 on socket error or when there
                     is no output socket
            :raises ValueError: when ``x`` has no ``dst`` attribute
            """
            try:
                sx = raw(x)
                if self.outs:
                    x.sent_time = time.time()
                    return self.outs.sendto(
                        sx,
                        (x.dst, 0)
                    )
            except AttributeError:
                raise ValueError(
                    "Missing 'dst' attribute in the first layer to be "
                    "sent using a native L3 socket ! (make sure you passed the "
                    "IP layer)"
                )
            except socket.error as msg:
                log_runtime.error(msg)
            return 0

    class L3RawSocket6(L3RawSocket):
        """IPv6 variant of L3RawSocket, sending through an AF_INET6 raw socket.

        Auxiliary data is not requested here, so ``recv()`` takes its
        timestamps from ``get_last_packet_timestamp()``.
        """
        def __init__(self,
                     type: int = ETH_P_IPV6,
                     filter: Optional[str] = None,
                     iface: Optional[_GlobInterfaceType] = None,
                     promisc: Optional[bool] = None,
                     nofilter: bool = False) -> None:
            # NOTE: if fragmentation is needed, it will be done by the kernel (RFC 2292)  # noqa: E501
            self.outs = socket.socket(
                socket.AF_INET6,
                socket.SOCK_RAW,
                socket.IPPROTO_RAW
            )
            self.ins = socket.socket(
                socket.AF_PACKET,
                socket.SOCK_RAW,
                socket.htons(type)
            )
            if iface is not None:
                # Honour the requested interface as L3RawSocket does;
                # an unbound AF_PACKET socket captures on every interface.
                self.ins.bind((network_name(iface), type))
            self.iface = cast(_GlobInterfaceType, iface)
427
428
class SimpleSocket(SuperSocket):
    """Expose an already-opened ``socket.socket`` as a SuperSocket.

    One socket serves for both reading and writing. Received bytes are
    tagged with ``basecls`` for dissection, which defaults to
    ``conf.raw_layer``.
    """
    desc = "wrapper around a classic socket"
    __selectable_force_select__ = True

    def __init__(self, sock, basecls=None):
        # type: (socket.socket, Optional[Type[Packet]]) -> None
        self.ins = self.outs = sock
        self.basecls = conf.raw_layer if basecls is None else basecls

    def recv_raw(self, x=MTU):
        # type: (int) -> Tuple[Optional[Type[Packet]], Optional[bytes], Optional[float]]
        """Read at most ``x`` bytes; return ``(basecls, data, None)``."""
        chunk = self.ins.recv(x)
        return self.basecls, chunk, None

    if WINDOWS:
        @staticmethod
        def select(sockets, remain=None):
            # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
            """Delegate readiness checks to ``select_objects``."""
            from scapy.automaton import select_objects
            return select_objects(sockets, remain)
451
452
class StreamSocket(SimpleSocket):
    """
    Wrap a stream socket into a layer 2 SuperSocket

    :param sock: the socket to wrap
    :param basecls: the base class packet to use to dissect the packet
    """
    desc = "transforms a stream socket into a layer 2"

    def __init__(self,
                 sock,  # type: socket.socket
                 basecls=None,  # type: Optional[Type[Packet]]
                 ):
        # type: (...) -> None
        from scapy.sessions import streamcls
        # Stream dissector; recv() treats a None result as "packet not
        # complete yet".
        self.rcvcls = streamcls(basecls or conf.raw_layer)
        self.metadata: Dict[str, Any] = {}
        self.streamsession: Dict[str, Any] = {}
        # Bytes already consumed from the socket that belong to a packet
        # which is not complete yet.
        self._buf = b""
        super(StreamSocket, self).__init__(sock, basecls=basecls)

    def recv(self, x=None, **kwargs):
        # type: (Optional[int], Any) -> Optional[Packet]
        """Receive exactly one packet from the stream.

        Data is read with ``MSG_PEEK`` first, so only the bytes belonging to
        the dissected packet are finally consumed; trailing bytes (padding)
        stay queued for the next call. Extra keyword arguments are accepted
        for API compatibility and ignored.

        :param x: maximum number of bytes per read, defaults to MTU
        :raises EOFError: when the peer closed the connection
        """
        bufsize = MTU if x is None else x
        # Accumulate data until the dissector returns a packet. A loop (not
        # recursion) keeps long messages arriving in many small chunks from
        # hitting the interpreter recursion limit.
        while True:
            # Block but in PEEK mode
            data = self.ins.recv(bufsize, socket.MSG_PEEK)
            if data == b"":
                raise EOFError
            pkt = self.rcvcls(self._buf + data, self.metadata,
                              self.streamsession)
            if pkt is not None:
                break
            # Incomplete packet: consume the peeked bytes and keep them
            self._buf += self.ins.recv(len(data))
        self.metadata.clear()
        to_read = len(data)
        # Strip any padding: those bytes belong to the following packet(s)
        pad = pkt.getlayer(conf.padding_layer)
        if pad is not None and pad.underlayer is not None:
            del pad.underlayer.payload
        while pad is not None and not isinstance(pad, NoPayload):
            to_read -= len(pad.load)
            pad = pad.payload
        # Only receive the packet length
        self.ins.recv(to_read)
        self._buf = b""
        return pkt
499
500
class SSLStreamSocket(StreamSocket):
    """StreamSocket variant for SSL/TLS-wrapped sockets.

    SSL sockets cannot be read with ``MSG_PEEK``. Instead, every chunk read
    is fed to a ``TCPSession(app=True)``, which returns a packet once one
    can be produced.
    """
    desc = "similar usage than StreamSocket but specialized for handling SSL-wrapped sockets"  # noqa: E501

    def __init__(self, sock, basecls=None):
        # type: (socket.socket, Optional[Type[Packet]]) -> None
        from scapy.sessions import TCPSession
        self.sess = TCPSession(app=True)
        super(SSLStreamSocket, self).__init__(sock, basecls)

    def recv(self, x=None, **kwargs):
        # type: (Optional[int], **Any) -> Optional[Packet]
        """Receive one packet, reading as many chunks as needed.

        :param x: maximum number of bytes per read, defaults to MTU
        :raises EOFError: on a socket error, or when the peer closed the
            connection and no packet could be produced
        """
        bufsize = MTU if x is None else x
        # Loop (instead of recursing) until the session yields a packet, so
        # many partial reads cannot exhaust the recursion limit.
        while True:
            # Block
            try:
                data = self.ins.recv(bufsize)
            except OSError:
                raise EOFError
            try:
                pkt = self.sess.process(data, cls=self.basecls)  # type: ignore
            except struct.error:
                # Buffer underflow
                pkt = None
            if pkt:
                return pkt
            if data == b"":
                raise EOFError

    @staticmethod
    def select(sockets, remain=None):
        # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
        """Return SSL sockets whose session already buffers data at once;
        otherwise defer to the parent class's ``select``."""
        queued = [
            x
            for x in sockets
            if isinstance(x, SSLStreamSocket) and x.sess.data
        ]
        if queued:
            return queued  # type: ignore
        return super(SSLStreamSocket, SSLStreamSocket).select(sockets, remain=remain)
544
545
class L2ListenTcpdump(SuperSocket):
    """Receive-only layer 2 socket reading the pcap output of tcpdump.

    tcpdump writes pcap data to its stdout (``-w -``), which is read back
    through a ``PcapReader``. ``outs`` is None, so nothing can be sent.
    """
    desc = "read packets at layer 2 using tcpdump"

    def __init__(self,
                 iface=None,  # type: Optional[_GlobInterfaceType]
                 promisc=None,  # type: Optional[bool]
                 filter=None,  # type: Optional[str]
                 nofilter=False,  # type: bool
                 prog=None,  # type: Optional[str]
                 quiet=False,  # type: bool
                 *arg,  # type: Any
                 **karg  # type: Any
                 ):
        # type: (...) -> None
        """
        :param iface: interface to sniff on; defaults to ``conf.iface`` on
                      Windows/macOS, elsewhere no ``-i`` option is passed
        :param promisc: sniff in promiscuous mode; ``conf.sniff_promisc``
                        when None
        :param filter: BPF filter expression
        :param nofilter: do not combine the filter with ``conf.except_filter``
        :param prog: tcpdump program to run
        :param quiet: forwarded to ``tcpdump()``
        """
        self.outs = None
        args = ['-w', '-', '-s', '65535']
        self.iface = "any"
        if iface is None and (WINDOWS or DARWIN):
            iface = conf.iface
        if iface is not None:
            # Report the interface actually handed to tcpdump
            self.iface = iface
            args.extend(['-i', network_name(iface)])
        if promisc is None:
            promisc = conf.sniff_promisc
        if not promisc:
            args.append('-p')
        if not nofilter:
            if conf.except_filter:
                if filter:
                    filter = "(%s) and not (%s)" % (filter, conf.except_filter)
                else:
                    filter = "not (%s)" % conf.except_filter
        if filter is not None:
            args.append(filter)
        self.tcpdump_proc = tcpdump(
            None, prog=prog, args=args, getproc=True, quiet=quiet)
        self.reader = PcapReader(self.tcpdump_proc.stdout)
        self.ins = self.reader  # type: ignore

    def recv(self, x=MTU, **kwargs):
        # type: (int, **Any) -> Optional[Packet]
        """Return the next packet read from tcpdump's pcap stream."""
        return self.reader.recv(x, **kwargs)

    def close(self):
        # type: () -> None
        """Close the pcap reader and terminate the tcpdump process.

        Safe to call several times, and safe when ``__init__`` failed
        before tcpdump was started (``close`` also runs from ``__del__``).
        """
        if self.closed:
            return
        SuperSocket.close(self)
        proc = getattr(self, "tcpdump_proc", None)
        if proc is not None:
            proc.kill()
            # Reap the child so it does not linger as a zombie process
            proc.wait()

    @staticmethod
    def select(sockets, remain=None):
        # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
        """On Windows/macOS every socket is reported ready; elsewhere
        ``SuperSocket.select`` is used."""
        if (WINDOWS or DARWIN):
            return sockets
        return SuperSocket.select(sockets, remain=remain)
599
600
601# More abstract objects
602
class IterSocket(SuperSocket):
    """Serve the elements of an iterable through the SuperSocket API.

    Each ``recv()`` call rebuilds the next element from its bytes with its
    own class. ``EOFError`` is raised once the iterable is exhausted.
    """
    desc = "wrapper around an iterable"
    nonblocking_socket = True

    def __init__(self, obj):
        # type: (_PacketIterable) -> None
        if not obj:
            self.iter = iter([])  # type: Iterator[Packet]
            return
        if isinstance(obj, IterSocket):
            # Share the wrapped IterSocket's iterator
            self.iter = obj.iter
        elif isinstance(obj, SndRcvList):
            self.iter = self._flatten_sndrcv(cast(SndRcvList, obj))
        elif isinstance(obj, (list, PacketList)):
            if isinstance(obj[0], bytes):
                self.iter = iter(obj)
            else:
                # Iterate each element in turn and chain what it yields
                self.iter = (
                    item
                    for element in obj
                    for item in element
                )
        else:
            self.iter = obj.__iter__()

    @staticmethod
    def _flatten_sndrcv(pairs):
        # type: (SndRcvList) -> Iterator[Packet]
        """Yield sent then received packet of each pair, giving each sent
        packet its ``sent_time`` as ``time`` when available."""
        for sent, received in pairs:
            if sent.sent_time:
                sent.time = sent.sent_time
            yield sent
            yield received

    @staticmethod
    def select(sockets, remain=None):
        # type: (List[SuperSocket], Any) -> List[SuperSocket]
        """Always report every socket as ready."""
        return sockets

    def recv(self, x=None, **kwargs):
        # type: (Optional[int], Any) -> Optional[Packet]
        """Return the next element re-dissected from its bytes."""
        try:
            current = next(self.iter)
            return current.__class__(bytes(current), **kwargs)
        except StopIteration:
            raise EOFError

    def close(self):
        # type: () -> None
        """Nothing to release."""
        pass
646