# SPDX-License-Identifier: GPL-2.0-only
# This file is part of Scapy
# See https://scapy.net/ for more information
# Copyright (C) Nils Weiss <nils@we155.de>

# scapy.contrib.description = TestSocket library for unit tests
# scapy.contrib.status = library

import time
import random

from threading import Lock

from scapy.config import conf
from scapy.automaton import ObjectPipe, select_objects
from scapy.data import MTU
from scapy.packet import Packet
from scapy.error import Scapy_Exception

# Typing imports
from typing import (
    Optional,
    Type,
    Tuple,
    Any,
    List,
)
from scapy.supersocket import SuperSocket

from scapy.plist import (
    PacketList,
    SndRcvList,
)


# Registry of every TestSocket that has been created and not yet closed.
# cleanup_testsockets() walks this list to close leftovers after a test.
open_test_sockets = []  # type: List[TestSocket]


class TestSocket(SuperSocket):
    """
    In-memory socket for unit tests.

    Two (or more) TestSockets are connected with :meth:`pair`. The bytes of
    a packet sent on one socket are pushed into the ``ins`` object pipe of
    every socket it is paired with, where they can be read back through
    :meth:`recv_raw`.
    """

    test_socket_mutex = Lock()

    def __init__(self,
                 basecls=None,  # type: Optional[Type[Packet]]
                 external_obj_pipe=None  # type: Optional[ObjectPipe[bytes]]
                 ):
        # type: (...) -> None
        """
        Create the socket and register it in ``open_test_sockets``.

        :param basecls: packet class reported by :meth:`recv_raw` for
            received data
        :param external_obj_pipe: object pipe to receive from. If omitted,
            a private ``ObjectPipe`` is created. An external pipe is never
            closed by :meth:`close`.
        """
        self.basecls = basecls
        self.paired_sockets = []  # type: List[TestSocket]
        self.ins = external_obj_pipe or ObjectPipe(name="TestSocket")  # type: ignore  # noqa: E501
        self._has_external_obj_pip = external_obj_pipe is not None
        self.outs = None
        open_test_sockets.append(self)

    def __enter__(self):
        # type: () -> TestSocket
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # type: (Optional[Type[BaseException]], Optional[BaseException], Optional[Any]) -> None  # noqa: E501
        """Close the socket"""
        self.close()

    def sr(self, *args, **kargs):
        # type: (Any, Any) -> Tuple[SndRcvList, PacketList]
        """Send and Receive multiple packets

        All arguments are forwarded to ``sendrecv.sndrcv``, which is always
        called with ``threaded=False``.
        """
        from scapy import sendrecv
        return sendrecv.sndrcv(self, *args, threaded=False, **kargs)

    def sr1(self, *args, **kargs):
        # type: (Any, Any) -> Optional[Packet]
        """Send one packet and receive one answer

        :returns: the answer of the first answered packet, or ``None`` if
            nothing was answered
        """
        from scapy import sendrecv
        answered = sendrecv.sndrcv(
            self, *args, threaded=False, **kargs
        )[0]  # type: SndRcvList
        if len(answered) > 0:
            pkt = answered[0][1]  # type: Packet
            return pkt
        return None

    def close(self):
        # type: () -> None
        """
        Close the socket, detach it from its peers and unregister it.

        Calling close() on an already closed socket does nothing.
        """
        if self.closed:
            return

        # Iterate over a copy: a socket paired with itself would otherwise
        # shrink the very list that is being iterated.
        for peer in list(self.paired_sockets):
            try:
                peer.paired_sockets.remove(self)
            except (ValueError, AttributeError, TypeError):
                pass

        if not self._has_external_obj_pip:
            super(TestSocket, self).close()
        else:
            # We don't close external object pipes
            self.closed = True

        try:
            open_test_sockets.remove(self)
        except (ValueError, AttributeError, TypeError):
            pass

    def pair(self, sock):
        # type: (TestSocket) -> None
        """Connect this socket and ``sock`` in both directions."""
        self.paired_sockets.append(sock)
        sock.paired_sockets.append(self)

    def send(self, x):
        # type: (Packet) -> int
        """
        Deliver the bytes of ``x`` to every paired socket.

        :param x: packet to send
        :returns: number of bytes of the serialized packet
        """
        sx = bytes(x)
        for peer in self.paired_sockets:
            peer.ins.send(sx)
        try:
            x.sent_time = time.time()
        except AttributeError:
            # Not every object handed to send() accepts a sent_time
            pass
        return len(sx)

    def recv_raw(self, x=MTU):
        # type: (int) -> Tuple[Optional[Type[Packet]], Optional[bytes], Optional[float]]  # noqa: E501
        """Returns a tuple containing (cls, pkt_data, time)"""
        # ``x`` is accepted for interface compatibility; it is not used here.
        return self.basecls, self.ins.recv(0), time.time()

    @staticmethod
    def select(sockets, remain=conf.recv_poll_rate):
        # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
        """Delegate to ``select_objects(sockets, remain)``."""
        return select_objects(sockets, remain)


class UnstableSocket(TestSocket):
    """
    This is an unstable socket which randomly fires exceptions or loses
    packets on recv.

    After every injected fault, the next 10 operations in the same
    direction (send or recv) are guaranteed to be fault free. Each
    direction has its own counter.
    """

    def __init__(self,
                 basecls=None,  # type: Optional[Type[Packet]]
                 external_obj_pipe=None  # type: Optional[ObjectPipe[bytes]]
                 ):
        # type: (...) -> None
        super(UnstableSocket, self).__init__(basecls, external_obj_pipe)
        # Remaining number of fault-free operations per direction
        self.no_error_for_x_rx_pkts = 10
        self.no_error_for_x_tx_pkts = 10

    def send(self, x):
        # type: (Packet) -> int
        """
        Send ``x``, occasionally failing instead.

        :raises OSError: randomly, once the fault-free budget for sending
            is used up
        """
        if self.no_error_for_x_tx_pkts == 0:
            if random.randint(0, 1000) == 42:
                self.no_error_for_x_tx_pkts = 10
                print("SOCKET CLOSED")
                raise OSError("Socket closed")
        if self.no_error_for_x_tx_pkts > 0:
            self.no_error_for_x_tx_pkts -= 1
        return super(UnstableSocket, self).send(x)

    def recv(self, x=MTU, **kwargs):
        # type: (int, **Any) -> Optional[Packet]
        """
        Receive a packet, occasionally failing or returning ``None`` instead.

        The receive path uses its own fault-free counter
        (``no_error_for_x_rx_pkts``), independent of the send path.

        :raises OSError: randomly
        :raises Scapy_Exception: randomly
        :raises ValueError: randomly
        """
        if self.no_error_for_x_rx_pkts == 0:
            if random.randint(0, 1000) == 42:
                self.no_error_for_x_rx_pkts = 10
                raise OSError("Socket closed")
            if random.randint(0, 1000) == 13:
                self.no_error_for_x_rx_pkts = 10
                raise Scapy_Exception("Socket closed")
            if random.randint(0, 1000) == 7:
                self.no_error_for_x_rx_pkts = 10
                raise ValueError("Socket closed")
            if random.randint(0, 1000) == 113:
                self.no_error_for_x_rx_pkts = 10
                return None
        if self.no_error_for_x_rx_pkts > 0:
            self.no_error_for_x_rx_pkts -= 1
        return super(UnstableSocket, self).recv(x, **kwargs)


def cleanup_testsockets():
    # type: () -> None
    """
    Helper function to remove TestSocket objects after a test
    """
    # Work on a snapshot: close() mutates open_test_sockets.
    for sock in list(open_test_sockets):
        sock.close()
        # close() returns early for an already closed socket and leaves it
        # registered, so drop the entry explicitly to guarantee that every
        # socket is visited and the registry ends up empty.
        try:
            open_test_sockets.remove(sock)
        except (ValueError, AttributeError, TypeError):
            pass