• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# SPDX-License-Identifier: GPL-2.0-only
2# This file is part of Scapy
3# See https://scapy.net/ for more information
4# Copyright (C) Nils Weiss <nils@we155.de>
5
6# scapy.contrib.description = TestSocket library for unit tests
7# scapy.contrib.status = library
8
9import time
10import random
11
12from threading import Lock
13
14from scapy.config import conf
15from scapy.automaton import ObjectPipe, select_objects
16from scapy.data import MTU
17from scapy.packet import Packet
18from scapy.error import Scapy_Exception
19
20# Typing imports
21from typing import (
22    Optional,
23    Type,
24    Tuple,
25    Any,
26    List,
27)
28from scapy.supersocket import SuperSocket
29
30from scapy.plist import (
31    PacketList,
32    SndRcvList,
33)
34
35
# Registry of TestSocket instances that have not been closed yet: every
# TestSocket appends itself on creation and removes itself in close().
open_test_sockets = []  # type: List[TestSocket]
37
38
class TestSocket(SuperSocket):
    """
    In-memory SuperSocket used by unit tests.

    Incoming data is read from ``self.ins``, an ``ObjectPipe`` that is either
    created by the socket or supplied by the caller. ``pair()`` links two
    sockets so that ``send()`` on one writes the raw bytes of the packet into
    the ``ins`` pipe of every paired socket. Each instance is recorded in the
    module-level ``open_test_sockets`` list until it is closed.
    """

    test_socket_mutex = Lock()

    def __init__(self,
                 basecls=None,  # type: Optional[Type[Packet]]
                 external_obj_pipe=None  # type: Optional[ObjectPipe[bytes]]
                 ):
        # type: (...) -> None
        """
        :param basecls: class reported by ``recv_raw()`` for received data
        :param external_obj_pipe: pipe to read incoming data from instead of
            a private one; ``close()`` leaves such a pipe open
        """
        # NOTE: SuperSocket.__init__ is not called; ``ins``/``outs`` are set
        # up by hand below.
        self.basecls = basecls
        self.paired_sockets = list()  # type: List[TestSocket]
        # Choose the pipe with the same ``is not None`` test that sets the
        # ownership flag below, so the two can never disagree (a plain
        # ``or`` would also replace a falsy external pipe).
        if external_obj_pipe is not None:
            self.ins = external_obj_pipe  # type: ignore
        else:
            self.ins = ObjectPipe(name="TestSocket")  # type: ignore
        self._has_external_obj_pip = external_obj_pipe is not None
        self.outs = None
        open_test_sockets.append(self)

    def __enter__(self):
        # type: () -> TestSocket
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # type: (Optional[Type[BaseException]], Optional[BaseException], Optional[Any]) -> None  # noqa: E501
        """Close the socket"""
        self.close()

    def sr(self, *args, **kargs):
        # type: (Any, Any) -> Tuple[SndRcvList, PacketList]
        """Send and Receive multiple packets

        Delegates to ``sendrecv.sndrcv`` with ``threaded=False``.
        """
        from scapy import sendrecv
        return sendrecv.sndrcv(self, *args, threaded=False, **kargs)

    def sr1(self, *args, **kargs):
        # type: (Any, Any) -> Optional[Packet]
        """Send one packet and receive one answer

        :returns: the answer of the first answered pair, or None when
            nothing was answered
        """
        from scapy import sendrecv
        ans = sendrecv.sndrcv(self, *args, threaded=False, **kargs)[0]  # type: SndRcvList
        if len(ans) > 0:
            pkt = ans[0][1]  # type: Packet
            return pkt
        else:
            return None

    def close(self):
        # type: () -> None
        """
        Close the socket.

        The socket is unlinked from all its peers (in both directions), its
        pipe is closed unless it was supplied externally, and it is removed
        from ``open_test_sockets``. Closing an already closed socket is a
        no-op.
        """
        if self.closed:
            return

        # Unlink in both directions: peers must stop writing into our pipe,
        # and a closed socket must neither keep delivering into theirs nor
        # keep them referenced. ``while`` covers sockets paired more than
        # once.
        for s in self.paired_sockets:
            try:
                while self in s.paired_sockets:
                    s.paired_sockets.remove(self)
            except (ValueError, AttributeError, TypeError):
                pass
        del self.paired_sockets[:]

        try:
            if not self._has_external_obj_pip:
                super(TestSocket, self).close()
            else:
                # We don't close external object pipes
                self.closed = True
        finally:
            # Drop the registry entry even if closing the pipe raised, so
            # the socket does not linger in open_test_sockets.
            try:
                open_test_sockets.remove(self)
            except (ValueError, AttributeError, TypeError):
                pass

    def pair(self, sock):
        # type: (TestSocket) -> None
        """
        Link this socket and ``sock`` so that each one receives what the
        other sends.
        """
        self.paired_sockets += [sock]
        sock.paired_sockets += [self]

    def send(self, x):
        # type: (Packet) -> int
        """
        Write ``bytes(x)`` into the ``ins`` pipe of every paired socket.

        :returns: the number of bytes sent
        """
        sx = bytes(x)
        for r in self.paired_sockets:
            r.ins.send(sx)
        # Record the send time when ``x`` accepts new attributes (raw bytes
        # do not).
        try:
            x.sent_time = time.time()
        except AttributeError:
            pass
        return len(sx)

    def recv_raw(self, x=MTU):
        # type: (int) -> Tuple[Optional[Type[Packet]], Optional[bytes], Optional[float]]  # noqa: E501
        """Returns a tuple containing (cls, pkt_data, time)

        ``cls`` is ``self.basecls``, ``pkt_data`` is the next object read
        from ``self.ins`` and ``time`` is the read timestamp. ``x`` is
        ignored.
        """
        return self.basecls, self.ins.recv(0), time.time()

    @staticmethod
    def select(sockets, remain=conf.recv_poll_rate):
        # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
        """
        Delegate readiness polling of ``sockets`` to ``select_objects``.

        NOTE: the default ``remain`` is read from ``conf.recv_poll_rate``
        once, when the class is defined.
        """
        return select_objects(sockets, remain)
133
134
class UnstableSocket(TestSocket):
    """
    This is an unstable socket which randomly fires exceptions on send and
    recv, or makes recv return None without reading a packet.

    Sends and receives have separate counters. The first
    ``_ERROR_FREE_PKTS`` calls in each direction, and the same number of
    calls after every injected failure, always succeed. After that, each
    call fails with a small random probability.
    """

    # Number of guaranteed error-free calls per direction, both after
    # creation and after each injected failure.
    _ERROR_FREE_PKTS = 10

    def __init__(self,
                 basecls=None,  # type: Optional[Type[Packet]]
                 external_obj_pipe=None  # type: Optional[ObjectPipe[bytes]]
                 ):
        # type: (...) -> None
        super(UnstableSocket, self).__init__(basecls, external_obj_pipe)
        self.no_error_for_x_rx_pkts = self._ERROR_FREE_PKTS
        self.no_error_for_x_tx_pkts = self._ERROR_FREE_PKTS

    def send(self, x):
        # type: (Packet) -> int
        """Send ``x``, occasionally raising OSError once the tx budget is used."""  # noqa: E501
        if self.no_error_for_x_tx_pkts == 0:
            if random.randint(0, 1000) == 42:
                self.no_error_for_x_tx_pkts = self._ERROR_FREE_PKTS
                print("SOCKET CLOSED")
                raise OSError("Socket closed")
        if self.no_error_for_x_tx_pkts > 0:
            self.no_error_for_x_tx_pkts -= 1
        return super(UnstableSocket, self).send(x)

    def recv(self, x=MTU, **kwargs):
        # type: (int, **Any) -> Optional[Packet]
        """
        Receive a packet. Once the rx budget is used, this occasionally
        raises OSError, Scapy_Exception or ValueError, or returns None
        without reading from the pipe.
        """
        # Receive-side failures are gated by the rx counter, independently
        # of send().
        if self.no_error_for_x_rx_pkts == 0:
            if random.randint(0, 1000) == 42:
                self.no_error_for_x_rx_pkts = self._ERROR_FREE_PKTS
                raise OSError("Socket closed")
            if random.randint(0, 1000) == 13:
                self.no_error_for_x_rx_pkts = self._ERROR_FREE_PKTS
                raise Scapy_Exception("Socket closed")
            if random.randint(0, 1000) == 7:
                self.no_error_for_x_rx_pkts = self._ERROR_FREE_PKTS
                raise ValueError("Socket closed")
            if random.randint(0, 1000) == 113:
                self.no_error_for_x_rx_pkts = self._ERROR_FREE_PKTS
                return None
        if self.no_error_for_x_rx_pkts > 0:
            self.no_error_for_x_rx_pkts -= 1
        return super(UnstableSocket, self).recv(x, **kwargs)
179
180
def cleanup_testsockets():
    # type: () -> None
    """
    Helper function to close all TestSocket objects left open after a test.

    Every socket registered in ``open_test_sockets`` is closed. The registry
    is emptied even for entries that ``close()`` does not unregister (e.g.
    sockets that are already flagged as closed, for which ``close()``
    returns early), so one stale entry cannot stop the others from being
    closed.
    """
    # Iterate over a snapshot: close() removes entries from the registry.
    for sock in list(open_test_sockets):
        sock.close()
        # Drop any entry close() left behind.
        while sock in open_test_sockets:
            open_test_sockets.remove(sock)
191