1# SPDX-License-Identifier: GPL-2.0-only 2# This file is part of Scapy 3# See https://scapy.net/ for more information 4# Copyright (C) Philippe Biondi <phil@secdev.org> 5# Copyright (C) Gabriel Potter 6 7""" 8Automata with states, transitions and actions. 9 10TODO: 11 - add documentation for ioevent, as_supersocket... 12""" 13 14import ctypes 15import itertools 16import logging 17import os 18import random 19import socket 20import sys 21import threading 22import time 23import traceback 24import types 25 26import select 27from collections import deque 28 29from scapy.config import conf 30from scapy.consts import WINDOWS 31from scapy.data import MTU 32from scapy.error import log_runtime, warning 33from scapy.interfaces import _GlobInterfaceType 34from scapy.packet import Packet 35from scapy.plist import PacketList 36from scapy.supersocket import SuperSocket, StreamSocket 37from scapy.utils import do_graph 38 39# Typing imports 40from typing import ( 41 Any, 42 Callable, 43 Deque, 44 Dict, 45 Generic, 46 Iterable, 47 Iterator, 48 List, 49 Optional, 50 Set, 51 Tuple, 52 Type, 53 TypeVar, 54 Union, 55 cast, 56) 57from scapy.compat import DecoratorCallable 58 59 60# winsock.h 61FD_READ = 0x00000001 62 63 64def select_objects(inputs, remain): 65 # type: (Iterable[Any], Union[float, int, None]) -> List[Any] 66 """ 67 Select objects. Same than: 68 ``select.select(inputs, [], [], remain)`` 69 70 But also works on Windows, only on objects whose fileno() returns 71 a Windows event. For simplicity, just use `ObjectPipe()` as a queue 72 that you can select on whatever the platform is. 73 74 If you want an object to be always included in the output of 75 select_objects (i.e. it's not selectable), just make fileno() 76 return a strictly negative value. 77 78 Example: 79 80 >>> a, b = ObjectPipe("a"), ObjectPipe("b") 81 >>> b.send("test") 82 >>> select_objects([a, b], 1) 83 [b] 84 85 :param inputs: objects to process 86 :param remain: timeout. If 0, poll. If None, block. 87 """ 88 if not WINDOWS: 89 return select.select(inputs, [], [], remain)[0] 90 inputs = list(inputs) 91 events = [] 92 created = [] 93 results = set() 94 for i in inputs: 95 if getattr(i, "__selectable_force_select__", False): 96 # Native socket.socket object. We would normally use select.select. 97 evt = ctypes.windll.ws2_32.WSACreateEvent() 98 created.append(evt) 99 res = ctypes.windll.ws2_32.WSAEventSelect( 100 ctypes.c_void_p(i.fileno()), 101 evt, 102 FD_READ 103 ) 104 if res == 0: 105 # Was a socket 106 events.append(evt) 107 else: 108 # Fallback to normal event 109 events.append(i.fileno()) 110 elif i.fileno() < 0: 111 # Special case: On Windows, we consider that an object that returns 112 # a negative fileno (impossible), is always readable. This is used 113 # in very few places but important (e.g. PcapReader), where we have 114 # no valid fileno (and will stop on EOFError). 115 results.add(i) 116 remain = 0 117 else: 118 events.append(i.fileno()) 119 if events: 120 # 0xFFFFFFFF = INFINITE 121 remainms = int(remain * 1000 if remain is not None else 0xFFFFFFFF) 122 if len(events) == 1: 123 res = ctypes.windll.kernel32.WaitForSingleObject( 124 ctypes.c_void_p(events[0]), 125 remainms 126 ) 127 else: 128 # Sadly, the only way to emulate select() is to first check 129 # if any object is available using WaitForMultipleObjects 130 # then poll the others. 131 res = ctypes.windll.kernel32.WaitForMultipleObjects( 132 len(events), 133 (ctypes.c_void_p * len(events))( 134 *events 135 ), 136 False, 137 remainms 138 ) 139 if res != 0xFFFFFFFF and res != 0x00000102: # Failed or Timeout 140 results.add(inputs[res]) 141 if len(events) > 1: 142 # Now poll the others, if any 143 for i, evt in enumerate(events): 144 res = ctypes.windll.kernel32.WaitForSingleObject( 145 ctypes.c_void_p(evt), 146 0 # poll: don't wait 147 ) 148 if res == 0: 149 results.add(inputs[i]) 150 # Cleanup created events, if any 151 for evt in created: 152 ctypes.windll.ws2_32.WSACloseEvent(evt) 153 return list(results) 154 155 156_T = TypeVar("_T") 157 158 159class ObjectPipe(Generic[_T]): 160 def __init__(self, name=None): 161 # type: (Optional[str]) -> None 162 self.name = name or "ObjectPipe" 163 self.closed = False 164 self.__rd, self.__wr = os.pipe() 165 self.__queue = deque() # type: Deque[_T] 166 if WINDOWS: 167 self._wincreate() 168 169 if WINDOWS: 170 def _wincreate(self): 171 # type: () -> None 172 self._fd = cast(int, ctypes.windll.kernel32.CreateEventA( 173 None, True, False, 174 ctypes.create_string_buffer(b"ObjectPipe %f" % random.random()) 175 )) 176 177 def _winset(self): 178 # type: () -> None 179 if ctypes.windll.kernel32.SetEvent(ctypes.c_void_p(self._fd)) == 0: 180 warning(ctypes.FormatError(ctypes.GetLastError())) 181 182 def _winreset(self): 183 # type: () -> None 184 if ctypes.windll.kernel32.ResetEvent(ctypes.c_void_p(self._fd)) == 0: 185 warning(ctypes.FormatError(ctypes.GetLastError())) 186 187 def _winclose(self): 188 # type: () -> None 189 if ctypes.windll.kernel32.CloseHandle(ctypes.c_void_p(self._fd)) == 0: 190 warning(ctypes.FormatError(ctypes.GetLastError())) 191 192 def fileno(self): 193 # type: () -> int 194 if WINDOWS: 195 return self._fd 196 return self.__rd 197 198 def send(self, obj): 199 # type: (_T) -> int 200 self.__queue.append(obj) 201 if WINDOWS: 202 self._winset() 203 os.write(self.__wr, b"X") 204 return 1 205 206 def write(self, obj): 207 # type: (_T) -> None 208 self.send(obj) 209 210 def empty(self): 211 # type: () -> bool 212 return not bool(self.__queue) 213 214 def flush(self): 215 # type: () -> None 216 pass 217 218 def recv(self, n=0, options=socket.MsgFlag(0)): 219 # type: (Optional[int], socket.MsgFlag) -> Optional[_T] 220 if self.closed: 221 raise EOFError 222 if options & socket.MSG_PEEK: 223 if self.__queue: 224 return self.__queue[0] 225 return None 226 os.read(self.__rd, 1) 227 elt = self.__queue.popleft() 228 if WINDOWS and not self.__queue: 229 self._winreset() 230 return elt 231 232 def read(self, n=0): 233 # type: (Optional[int]) -> Optional[_T] 234 return self.recv(n) 235 236 def clear(self): 237 # type: () -> None 238 if not self.closed: 239 while not self.empty(): 240 self.recv() 241 242 def close(self): 243 # type: () -> None 244 if not self.closed: 245 self.closed = True 246 os.close(self.__rd) 247 os.close(self.__wr) 248 if WINDOWS: 249 try: 250 self._winclose() 251 except ImportError: 252 # Python is shutting down 253 pass 254 255 def __repr__(self): 256 # type: () -> str 257 return "<%s at %s>" % (self.name, id(self)) 258 259 def __del__(self): 260 # type: () -> None 261 self.close() 262 263 @staticmethod 264 def select(sockets, remain=conf.recv_poll_rate): 265 # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket] 266 # Only handle ObjectPipes 267 results = [] 268 for s in sockets: 269 if s.closed: # allow read to trigger EOF 270 results.append(s) 271 if results: 272 return results 273 return select_objects(sockets, remain) 274 275 276class Message: 277 type = None # type: str 278 pkt = None # type: Packet 279 result = None # type: str 280 state = None # type: Message 281 exc_info = None # type: Union[Tuple[None, None, None], Tuple[BaseException, Exception, types.TracebackType]] # noqa: E501 282 283 def __init__(self, **args): 284 # type: (Any) -> None 285 self.__dict__.update(args) 286 287 def __repr__(self): 288 # type: () -> str 289 return "<Message %s>" % " ".join( 290 "%s=%r" % (k, v) 291 for k, v in self.__dict__.items() 292 if not k.startswith("_") 293 ) 294 295 296class Timer(): 297 def __init__(self, time, prio=0, autoreload=False): 298 # type: (Union[int, float], int, bool) -> None 299 self._timeout = float(time) # type: float 300 self._time = 0 # type: float 301 self._just_expired = True 302 self._expired = True 303 self._prio = prio 304 self._func = _StateWrapper() 305 self._autoreload = autoreload 306 307 def get(self): 308 # type: () -> float 309 return self._timeout 310 311 def set(self, val): 312 # type: (float) -> None 313 self._timeout = val 314 315 def _reset(self): 316 # type: () -> None 317 self._time = self._timeout 318 self._expired = False 319 self._just_expired = False 320 321 def _reset_just_expired(self): 322 # type: () -> None 323 self._just_expired = False 324 325 def _running(self): 326 # type: () -> bool 327 return self._time > 0 328 329 def _remaining(self): 330 # type: () -> float 331 return max(self._time, 0) 332 333 def _decrement(self, time): 334 # type: (float) -> None 335 self._time -= time 336 if self._time <= 0: 337 if not self._expired: 338 self._just_expired = True 339 if self._autoreload: 340 # take overshoot into account 341 self._time = self._timeout + self._time 342 else: 343 self._expired = True 344 self._time = 0 345 346 def __lt__(self, obj): 347 # type: (Timer) -> bool 348 return ((self._time < obj._time) if self._time != obj._time 349 else (self._prio < obj._prio)) 350 351 def __gt__(self, obj): 352 # type: (Timer) -> bool 353 return ((self._time > obj._time) if self._time != obj._time 354 else (self._prio > obj._prio)) 355 356 def __eq__(self, obj): 357 # type: (Any) -> bool 358 if not isinstance(obj, Timer): 359 raise NotImplementedError() 360 return (self._time == obj._time) and (self._prio == obj._prio) 361 362 def __repr__(self): 363 # type: () -> str 364 return "<Timer %f(%f)>" % (self._time, self._timeout) 365 366 367class _TimerList(): 368 def __init__(self): 369 # type: () -> None 370 self.timers = [] # type: list[Timer] 371 372 def add_timer(self, timer): 373 # type: (Timer) -> None 374 self.timers.append(timer) 375 376 def reset(self): 377 # type: () -> None 378 for t in self.timers: 379 t._reset() 380 381 def decrement(self, time): 382 # type: (float) -> None 383 for t in self.timers: 384 t._decrement(time) 385 386 def expired(self): 387 # type: () -> list[Timer] 388 lst = [t for t in self.timers if t._just_expired] 389 lst.sort(key=lambda x: x._prio, reverse=True) 390 for t in lst: 391 t._reset_just_expired() 392 return lst 393 394 def until_next(self): 395 # type: () -> Optional[float] 396 try: 397 return min([t._remaining() for t in self.timers if t._running()]) 398 except ValueError: 399 return None # None means blocking 400 401 def count(self): 402 # type: () -> int 403 return len(self.timers) 404 405 def __iter__(self): 406 # type: () -> Iterator[Timer] 407 return self.timers.__iter__() 408 409 def __repr__(self): 410 # type: () -> str 411 return self.timers.__repr__() 412 413 414class _instance_state: 415 def __init__(self, instance): 416 # type: (Any) -> None 417 self.__self__ = instance.__self__ 418 self.__func__ = instance.__func__ 419 self.__self__.__class__ = instance.__self__.__class__ 420 421 def __getattr__(self, attr): 422 # type: (str) -> Any 423 return getattr(self.__func__, attr) 424 425 def __call__(self, *args, **kargs): 426 # type: (Any, Any) -> Any 427 return self.__func__(self.__self__, *args, **kargs) 428 429 def breaks(self): 430 # type: () -> Any 431 return self.__self__.add_breakpoints(self.__func__) 432 433 def intercepts(self): 434 # type: () -> Any 435 return self.__self__.add_interception_points(self.__func__) 436 437 def unbreaks(self): 438 # type: () -> Any 439 return self.__self__.remove_breakpoints(self.__func__) 440 441 def unintercepts(self): 442 # type: () -> Any 443 return self.__self__.remove_interception_points(self.__func__) 444 445 446############## 447# Automata # 448############## 449 450class _StateWrapper: 451 __name__ = None # type: str 452 atmt_type = None # type: str 453 atmt_state = None # type: str 454 atmt_initial = None # type: int 455 atmt_final = None # type: int 456 atmt_stop = None # type: int 457 atmt_error = None # type: int 458 atmt_origfunc = None # type: _StateWrapper 459 atmt_prio = None # type: int 460 atmt_as_supersocket = None # type: Optional[str] 461 atmt_condname = None # type: str 462 atmt_ioname = None # type: str 463 atmt_timeout = None # type: Timer 464 atmt_cond = None # type: Dict[str, int] 465 __code__ = None # type: types.CodeType 466 __call__ = None # type: Callable[..., ATMT.NewStateRequested] 467 468 469class ATMT: 470 STATE = "State" 471 ACTION = "Action" 472 CONDITION = "Condition" 473 RECV = "Receive condition" 474 TIMEOUT = "Timeout condition" 475 EOF = "EOF condition" 476 IOEVENT = "I/O event" 477 478 class NewStateRequested(Exception): 479 def __init__(self, state_func, automaton, *args, **kargs): 480 # type: (Any, ATMT, Any, Any) -> None 481 self.func = state_func 482 self.state = state_func.atmt_state 483 self.initial = state_func.atmt_initial 484 self.error = state_func.atmt_error 485 self.stop = state_func.atmt_stop 486 self.final = state_func.atmt_final 487 Exception.__init__(self, "Request state [%s]" % self.state) 488 self.automaton = automaton 489 self.args = args 490 self.kargs = kargs 491 self.action_parameters() # init action parameters 492 493 def action_parameters(self, *args, **kargs): 494 # type: (Any, Any) -> ATMT.NewStateRequested 495 self.action_args = args 496 self.action_kargs = kargs 497 return self 498 499 def run(self): 500 # type: () -> Any 501 return self.func(self.automaton, *self.args, **self.kargs) 502 503 def __repr__(self): 504 # type: () -> str 505 return "NewStateRequested(%s)" % self.state 506 507 @staticmethod 508 def state(initial=0, # type: int 509 final=0, # type: int 510 stop=0, # type: int 511 error=0 # type: int 512 ): 513 # type: (...) -> Callable[[DecoratorCallable], DecoratorCallable] 514 def deco(f, initial=initial, final=final): 515 # type: (_StateWrapper, int, int) -> _StateWrapper 516 f.atmt_type = ATMT.STATE 517 f.atmt_state = f.__name__ 518 f.atmt_initial = initial 519 f.atmt_final = final 520 f.atmt_stop = stop 521 f.atmt_error = error 522 523 def _state_wrapper(self, *args, **kargs): 524 # type: (ATMT, Any, Any) -> ATMT.NewStateRequested 525 return ATMT.NewStateRequested(f, self, *args, **kargs) 526 527 state_wrapper = cast(_StateWrapper, _state_wrapper) 528 state_wrapper.__name__ = "%s_wrapper" % f.__name__ 529 state_wrapper.atmt_type = ATMT.STATE 530 state_wrapper.atmt_state = f.__name__ 531 state_wrapper.atmt_initial = initial 532 state_wrapper.atmt_final = final 533 state_wrapper.atmt_stop = stop 534 state_wrapper.atmt_error = error 535 state_wrapper.atmt_origfunc = f 536 return state_wrapper 537 return deco # type: ignore 538 539 @staticmethod 540 def action(cond, prio=0): 541 # type: (Any, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501 542 def deco(f, cond=cond): 543 # type: (_StateWrapper, _StateWrapper) -> _StateWrapper 544 if not hasattr(f, "atmt_type"): 545 f.atmt_cond = {} 546 f.atmt_type = ATMT.ACTION 547 f.atmt_cond[cond.atmt_condname] = prio 548 return f 549 return deco 550 551 @staticmethod 552 def condition(state, prio=0): 553 # type: (Any, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501 554 def deco(f, state=state): 555 # type: (_StateWrapper, _StateWrapper) -> Any 556 f.atmt_type = ATMT.CONDITION 557 f.atmt_state = state.atmt_state 558 f.atmt_condname = f.__name__ 559 f.atmt_prio = prio 560 return f 561 return deco 562 563 @staticmethod 564 def receive_condition(state, prio=0): 565 # type: (_StateWrapper, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501 566 def deco(f, state=state): 567 # type: (_StateWrapper, _StateWrapper) -> _StateWrapper 568 f.atmt_type = ATMT.RECV 569 f.atmt_state = state.atmt_state 570 f.atmt_condname = f.__name__ 571 f.atmt_prio = prio 572 return f 573 return deco 574 575 @staticmethod 576 def ioevent(state, # type: _StateWrapper 577 name, # type: str 578 prio=0, # type: int 579 as_supersocket=None # type: Optional[str] 580 ): 581 # type: (...) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501 582 def deco(f, state=state): 583 # type: (_StateWrapper, _StateWrapper) -> _StateWrapper 584 f.atmt_type = ATMT.IOEVENT 585 f.atmt_state = state.atmt_state 586 f.atmt_condname = f.__name__ 587 f.atmt_ioname = name 588 f.atmt_prio = prio 589 f.atmt_as_supersocket = as_supersocket 590 return f 591 return deco 592 593 @staticmethod 594 def timeout(state, timeout): 595 # type: (_StateWrapper, Union[int, float]) -> Callable[[_StateWrapper, _StateWrapper, Timer], _StateWrapper] # noqa: E501 596 def deco(f, state=state, timeout=Timer(timeout)): 597 # type: (_StateWrapper, _StateWrapper, Timer) -> _StateWrapper 598 f.atmt_type = ATMT.TIMEOUT 599 f.atmt_state = state.atmt_state 600 f.atmt_timeout = timeout 601 f.atmt_timeout._func = f 602 f.atmt_condname = f.__name__ 603 return f 604 return deco 605 606 @staticmethod 607 def timer(state, timeout, prio=0): 608 # type: (_StateWrapper, Union[float, int], int) -> Callable[[_StateWrapper, _StateWrapper, Timer], _StateWrapper] # noqa: E501 609 def deco(f, state=state, timeout=Timer(timeout, prio=prio, autoreload=True)): 610 # type: (_StateWrapper, _StateWrapper, Timer) -> _StateWrapper 611 f.atmt_type = ATMT.TIMEOUT 612 f.atmt_state = state.atmt_state 613 f.atmt_timeout = timeout 614 f.atmt_timeout._func = f 615 f.atmt_condname = f.__name__ 616 return f 617 return deco 618 619 @staticmethod 620 def eof(state): 621 # type: (_StateWrapper) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper] # noqa: E501 622 def deco(f, state=state): 623 # type: (_StateWrapper, _StateWrapper) -> _StateWrapper 624 f.atmt_type = ATMT.EOF 625 f.atmt_state = state.atmt_state 626 f.atmt_condname = f.__name__ 627 return f 628 return deco 629 630 631class _ATMT_Command: 632 RUN = "RUN" 633 NEXT = "NEXT" 634 FREEZE = "FREEZE" 635 STOP = "STOP" 636 FORCESTOP = "FORCESTOP" 637 END = "END" 638 EXCEPTION = "EXCEPTION" 639 SINGLESTEP = "SINGLESTEP" 640 BREAKPOINT = "BREAKPOINT" 641 INTERCEPT = "INTERCEPT" 642 ACCEPT = "ACCEPT" 643 REPLACE = "REPLACE" 644 REJECT = "REJECT" 645 646 647class _ATMT_supersocket(SuperSocket): 648 def __init__(self, 649 name, # type: str 650 ioevent, # type: str 651 automaton, # type: Type[Automaton] 652 proto, # type: Callable[[bytes], Any] 653 *args, # type: Any 654 **kargs # type: Any 655 ): 656 # type: (...) -> None 657 self.name = name 658 self.ioevent = ioevent 659 self.proto = proto 660 # write, read 661 self.spa, self.spb = ObjectPipe[Any]("spa"), \ 662 ObjectPipe[Any]("spb") 663 kargs["external_fd"] = {ioevent: (self.spa, self.spb)} 664 kargs["is_atmt_socket"] = True 665 kargs["atmt_socket"] = self.name 666 self.atmt = automaton(*args, **kargs) 667 self.atmt.runbg() 668 669 def send(self, s): 670 # type: (Any) -> int 671 return self.spa.send(s) 672 673 def fileno(self): 674 # type: () -> int 675 return self.spb.fileno() 676 677 # note: _ATMT_supersocket may return bytes in certain cases, which 678 # is expected. We cheat on typing. 679 def recv(self, n=MTU, **kwargs): # type: ignore 680 # type: (int, **Any) -> Any 681 r = self.spb.recv(n) 682 if self.proto is not None and r is not None: 683 r = self.proto(r, **kwargs) 684 return r 685 686 def close(self): 687 # type: () -> None 688 if not self.closed: 689 self.atmt.stop() 690 self.atmt.destroy() 691 self.spa.close() 692 self.spb.close() 693 self.closed = True 694 695 @staticmethod 696 def select(sockets, remain=conf.recv_poll_rate): 697 # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket] 698 return select_objects(sockets, remain) 699 700 701class _ATMT_to_supersocket: 702 def __init__(self, name, ioevent, automaton): 703 # type: (str, str, Type[Automaton]) -> None 704 self.name = name 705 self.ioevent = ioevent 706 self.automaton = automaton 707 708 def __call__(self, proto, *args, **kargs): 709 # type: (Callable[[bytes], Any], Any, Any) -> _ATMT_supersocket 710 return _ATMT_supersocket( 711 self.name, self.ioevent, self.automaton, 712 proto, *args, **kargs 713 ) 714 715 716class Automaton_metaclass(type): 717 def __new__(cls, name, bases, dct): 718 # type: (str, Tuple[Any], Dict[str, Any]) -> Type[Automaton] 719 cls = super(Automaton_metaclass, cls).__new__( # type: ignore 720 cls, name, bases, dct 721 ) 722 cls.states = {} 723 cls.recv_conditions = {} # type: Dict[str, List[_StateWrapper]] 724 cls.conditions = {} # type: Dict[str, List[_StateWrapper]] 725 cls.ioevents = {} # type: Dict[str, List[_StateWrapper]] 726 cls.timeout = {} # type: Dict[str, _TimerList] 727 cls.eofs = {} # type: Dict[str, _StateWrapper] 728 cls.actions = {} # type: Dict[str, List[_StateWrapper]] 729 cls.initial_states = [] # type: List[_StateWrapper] 730 cls.stop_state = None # type: Optional[_StateWrapper] 731 cls.ionames = [] 732 cls.iosupersockets = [] 733 734 members = {} 735 classes = [cls] 736 while classes: 737 c = classes.pop(0) # order is important to avoid breaking method overloading # noqa: E501 738 classes += list(c.__bases__) 739 for k, v in c.__dict__.items(): # type: ignore 740 if k not in members: 741 members[k] = v 742 743 decorated = [v for v in members.values() 744 if hasattr(v, "atmt_type")] 745 746 for m in decorated: 747 if m.atmt_type == ATMT.STATE: 748 s = m.atmt_state 749 cls.states[s] = m 750 cls.recv_conditions[s] = [] 751 cls.ioevents[s] = [] 752 cls.conditions[s] = [] 753 cls.timeout[s] = _TimerList() 754 if m.atmt_initial: 755 cls.initial_states.append(m) 756 if m.atmt_stop: 757 if cls.stop_state is not None: 758 raise ValueError("There can only be a single stop state !") 759 cls.stop_state = m 760 elif m.atmt_type in [ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT, ATMT.IOEVENT, ATMT.EOF]: # noqa: E501 761 cls.actions[m.atmt_condname] = [] 762 763 for m in decorated: 764 if m.atmt_type == ATMT.CONDITION: 765 cls.conditions[m.atmt_state].append(m) 766 elif m.atmt_type == ATMT.RECV: 767 cls.recv_conditions[m.atmt_state].append(m) 768 elif m.atmt_type == ATMT.EOF: 769 cls.eofs[m.atmt_state] = m 770 elif m.atmt_type == ATMT.IOEVENT: 771 cls.ioevents[m.atmt_state].append(m) 772 cls.ionames.append(m.atmt_ioname) 773 if m.atmt_as_supersocket is not None: 774 cls.iosupersockets.append(m) 775 elif m.atmt_type == ATMT.TIMEOUT: 776 cls.timeout[m.atmt_state].add_timer(m.atmt_timeout) 777 elif m.atmt_type == ATMT.ACTION: 778 for co in m.atmt_cond: 779 cls.actions[co].append(m) 780 781 for v in itertools.chain( 782 cls.conditions.values(), 783 cls.recv_conditions.values(), 784 cls.ioevents.values() 785 ): 786 v.sort(key=lambda x: x.atmt_prio) 787 for condname, actlst in cls.actions.items(): 788 actlst.sort(key=lambda x: x.atmt_cond[condname]) 789 790 for ioev in cls.iosupersockets: 791 setattr(cls, ioev.atmt_as_supersocket, 792 _ATMT_to_supersocket( 793 ioev.atmt_as_supersocket, 794 ioev.atmt_ioname, 795 cast(Type["Automaton"], cls))) 796 797 # Inject signature 798 try: 799 import inspect 800 cls.__signature__ = inspect.signature(cls.parse_args) # type: ignore # noqa: E501 801 except (ImportError, AttributeError): 802 pass 803 804 return cast(Type["Automaton"], cls) 805 806 def build_graph(self): 807 # type: () -> str 808 s = 'digraph "%s" {\n' % self.__class__.__name__ 809 810 se = "" # Keep initial nodes at the beginning for better rendering 811 for st in self.states.values(): 812 if st.atmt_initial: 813 se = ('\t"%s" [ style=filled, fillcolor=blue, shape=box, root=true];\n' % st.atmt_state) + se # noqa: E501 814 elif st.atmt_final: 815 se += '\t"%s" [ style=filled, fillcolor=green, shape=octagon ];\n' % st.atmt_state # noqa: E501 816 elif st.atmt_error: 817 se += '\t"%s" [ style=filled, fillcolor=red, shape=octagon ];\n' % st.atmt_state # noqa: E501 818 elif st.atmt_stop: 819 se += '\t"%s" [ style=filled, fillcolor=orange, shape=box, root=true ];\n' % st.atmt_state # noqa: E501 820 s += se 821 822 for st in self.states.values(): 823 names = list( 824 st.atmt_origfunc.__code__.co_names + 825 st.atmt_origfunc.__code__.co_consts 826 ) 827 while names: 828 n = names.pop() 829 if n in self.states: 830 s += '\t"%s" -> "%s" [ color=green ];\n' % (st.atmt_state, n) 831 elif n in self.__dict__: 832 # function indirection 833 if callable(self.__dict__[n]): 834 names.extend(self.__dict__[n].__code__.co_names) 835 names.extend(self.__dict__[n].__code__.co_consts) 836 837 for c, sty, k, v in ( 838 [("purple", "solid", k, v) for k, v in self.conditions.items()] + 839 [("red", "solid", k, v) for k, v in self.recv_conditions.items()] + 840 [("orange", "solid", k, v) for k, v in self.ioevents.items()] + 841 [("black", "dashed", k, [v]) for k, v in self.eofs.items()] 842 ): 843 for f in v: 844 names = list(f.__code__.co_names + f.__code__.co_consts) 845 while names: 846 n = names.pop() 847 if n in self.states: 848 line = f.atmt_condname 849 for x in self.actions[f.atmt_condname]: 850 line += "\\l>[%s]" % x.__name__ 851 s += '\t"%s" -> "%s" [label="%s", color=%s, style=%s];\n' % ( 852 k, 853 n, 854 line, 855 c, 856 sty, 857 ) 858 elif n in self.__dict__: 859 # function indirection 860 if callable(self.__dict__[n]) and hasattr(self.__dict__[n], "__code__"): # noqa: E501 861 names.extend(self.__dict__[n].__code__.co_names) 862 names.extend(self.__dict__[n].__code__.co_consts) 863 for k, timers in self.timeout.items(): 864 for timer in timers: 865 for n in (timer._func.__code__.co_names + 866 timer._func.__code__.co_consts): 867 if n in self.states: 868 line = "%s/%.1fs" % (timer._func.atmt_condname, 869 timer.get()) 870 for x in self.actions[timer._func.atmt_condname]: 871 line += "\\l>[%s]" % x.__name__ 872 s += '\t"%s" -> "%s" [label="%s",color=blue];\n' % (k, n, line) # noqa: E501 873 s += "}\n" 874 return s 875 876 def graph(self, **kargs): 877 # type: (Any) -> Optional[str] 878 s = self.build_graph() 879 return do_graph(s, **kargs) 880 881 882class Automaton(metaclass=Automaton_metaclass): 883 states = {} # type: Dict[str, _StateWrapper] 884 state = None # type: ATMT.NewStateRequested 885 recv_conditions = {} # type: Dict[str, List[_StateWrapper]] 886 conditions = {} # type: Dict[str, List[_StateWrapper]] 887 eofs = {} # type: Dict[str, _StateWrapper] 888 ioevents = {} # type: Dict[str, List[_StateWrapper]] 889 timeout = {} # type: Dict[str, _TimerList] 890 actions = {} # type: Dict[str, List[_StateWrapper]] 891 initial_states = [] # type: List[_StateWrapper] 892 stop_state = None # type: Optional[_StateWrapper] 893 ionames = [] # type: List[str] 894 iosupersockets = [] # type: List[SuperSocket] 895 896 # used for spawn() 897 pkt_cls = conf.raw_layer 898 socketcls = StreamSocket 899 900 # Internals 901 def __init__(self, *args, **kargs): 902 # type: (Any, Any) -> None 903 external_fd = kargs.pop("external_fd", {}) 904 if "sock" in kargs: 905 # We use a bi-directional sock 906 self.sock = kargs["sock"] 907 else: 908 # Separate sockets 909 self.sock = None 910 self.send_sock_class = kargs.pop("ll", conf.L3socket) 911 self.recv_sock_class = kargs.pop("recvsock", conf.L2listen) 912 self.listen_sock = None # type: Optional[SuperSocket] 913 self.send_sock = None # type: Optional[SuperSocket] 914 self.is_atmt_socket = kargs.pop("is_atmt_socket", False) 915 self.atmt_socket = kargs.pop("atmt_socket", None) 916 self.started = threading.Lock() 917 self.threadid = None # type: Optional[int] 918 self.breakpointed = None 919 self.breakpoints = set() # type: Set[_StateWrapper] 920 self.interception_points = set() # type: Set[_StateWrapper] 921 self.intercepted_packet = None # type: Union[None, Packet] 922 self.debug_level = 0 923 self.init_args = args 924 self.init_kargs = kargs 925 self.io = type.__new__(type, "IOnamespace", (), {}) 926 self.oi = type.__new__(type, "IOnamespace", (), {}) 927 self.cmdin = ObjectPipe[Message]("cmdin") 928 self.cmdout = ObjectPipe[Message]("cmdout") 929 self.ioin = {} 930 self.ioout = {} 931 self.packets = PacketList() # type: PacketList 932 for n in self.__class__.ionames: 933 extfd = external_fd.get(n) 934 if not isinstance(extfd, tuple): 935 extfd = (extfd, extfd) 936 ioin, ioout = extfd 937 if ioin is None: 938 ioin = ObjectPipe("ioin") 939 else: 940 ioin = self._IO_fdwrapper(ioin, None) 941 if ioout is None: 942 ioout = ObjectPipe("ioout") 943 else: 944 ioout = self._IO_fdwrapper(None, ioout) 945 946 self.ioin[n] = ioin 947 self.ioout[n] = ioout 948 ioin.ioname = n 949 ioout.ioname = n 950 setattr(self.io, n, self._IO_mixer(ioout, ioin)) 951 setattr(self.oi, n, self._IO_mixer(ioin, ioout)) 952 953 for stname in self.states: 954 setattr(self, stname, 955 _instance_state(getattr(self, stname))) 956 957 self.start() 958 959 def parse_args(self, debug=0, store=0, **kargs): 960 # type: (int, int, Any) -> None 961 self.debug_level = debug 962 if debug: 963 conf.logLevel = logging.DEBUG 964 self.socket_kargs = kargs 965 self.store_packets = store 966 967 @classmethod 968 def spawn(cls, 969 port: int, 970 iface: Optional[_GlobInterfaceType] = None, 971 bg: bool = False, 972 **kwargs: Any) -> Optional[socket.socket]: 973 """ 974 Spawn a TCP server that listens for connections and start the automaton 975 for each new client. 976 977 :param port: the port to listen to 978 :param bg: background mode? (default: False) 979 980 Note that in background mode, you shall close the TCP server as such:: 981 982 srv = MyAutomaton.spawn(8080, bg=True) 983 srv.shutdown(socket.SHUT_RDWR) # important 984 srv.close() 985 """ 986 from scapy.arch import get_if_addr 987 # create server sock and bind it 988 ssock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 989 local_ip = get_if_addr(iface or conf.iface) 990 try: 991 ssock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 992 except OSError: 993 pass 994 ssock.bind((local_ip, port)) 995 ssock.listen(5) 996 clients = [] 997 if kwargs.get("verb", True): 998 print(conf.color_theme.green( 999 "Server %s started listening on %s" % ( 1000 cls.__name__, 1001 (local_ip, port), 1002 ) 1003 )) 1004 1005 def _run() -> None: 1006 # Wait for clients forever 1007 try: 1008 while True: 1009 atmt_server = None 1010 clientsocket, address = ssock.accept() 1011 if kwargs.get("verb", True): 1012 print(conf.color_theme.gold( 1013 "\u2503 Connection received from %s" % repr(address) 1014 )) 1015 try: 1016 # Start atmt class with socket 1017 if cls.socketcls is not None: 1018 sock = cls.socketcls(clientsocket, cls.pkt_cls) 1019 else: 1020 sock = clientsocket 1021 atmt_server = cls( 1022 sock=sock, 1023 iface=iface, **kwargs 1024 ) 1025 except OSError: 1026 if atmt_server is not None: 1027 atmt_server.destroy() 1028 if kwargs.get("verb", True): 1029 print("X Connection aborted.") 1030 if kwargs.get("debug", 0) > 0: 1031 traceback.print_exc() 1032 continue 1033 clients.append((atmt_server, clientsocket)) 1034 # start atmt 1035 atmt_server.runbg() 1036 # housekeeping 1037 for atmt, clientsocket in clients: 1038 if not atmt.isrunning(): 1039 atmt.destroy() 1040 except KeyboardInterrupt: 1041 print("X Exiting.") 1042 ssock.shutdown(socket.SHUT_RDWR) 1043 except OSError: 1044 print("X Server closed.") 1045 if kwargs.get("debug", 0) > 0: 1046 traceback.print_exc() 1047 finally: 1048 for atmt, clientsocket in clients: 1049 try: 1050 atmt.forcestop(wait=False) 1051 atmt.destroy() 1052 except Exception: 1053 pass 1054 try: 1055 clientsocket.shutdown(socket.SHUT_RDWR) 1056 clientsocket.close() 1057 except Exception: 1058 pass 1059 ssock.close() 1060 if bg: 1061 # Background 1062 threading.Thread(target=_run).start() 1063 return ssock 1064 else: 1065 # Non-background 1066 _run() 1067 return None 1068 1069 def master_filter(self, pkt): 1070 # type: (Packet) -> bool 1071 return True 1072 1073 def my_send(self, pkt, **kwargs): 1074 # type: (Packet, **Any) -> None 1075 if not self.send_sock: 1076 raise ValueError("send_sock is None !") 1077 self.send_sock.send(pkt, **kwargs) 1078 1079 def update_sock(self, sock): 1080 # type: (SuperSocket) -> None 1081 """ 1082 Update the socket used by the automata. 1083 Typically used in an eof event to reconnect. 1084 """ 1085 self.sock = sock 1086 if self.listen_sock is not None: 1087 self.listen_sock = self.sock 1088 if self.send_sock: 1089 self.send_sock = self.sock 1090 1091 def timer_by_name(self, name): 1092 # type: (str) -> Optional[Timer] 1093 for _, timers in self.timeout.items(): 1094 for timer in timers: # type: Timer 1095 if timer._func.atmt_condname == name: 1096 return timer 1097 return None 1098 1099 # Utility classes and exceptions 1100 class _IO_fdwrapper: 1101 def __init__(self, 1102 rd, # type: Union[int, ObjectPipe[bytes], None] 1103 wr # type: Union[int, ObjectPipe[bytes], None] 1104 ): 1105 # type: (...) -> None 1106 self.rd = rd 1107 self.wr = wr 1108 if isinstance(self.rd, socket.socket): 1109 self.__selectable_force_select__ = True 1110 1111 def fileno(self): 1112 # type: () -> int 1113 if isinstance(self.rd, int): 1114 return self.rd 1115 elif self.rd: 1116 return self.rd.fileno() 1117 return 0 1118 1119 def read(self, n=65535): 1120 # type: (int) -> Optional[bytes] 1121 if isinstance(self.rd, int): 1122 return os.read(self.rd, n) 1123 elif self.rd: 1124 return self.rd.recv(n) 1125 return None 1126 1127 def write(self, msg): 1128 # type: (bytes) -> int 1129 if isinstance(self.wr, int): 1130 return os.write(self.wr, msg) 1131 elif self.wr: 1132 return self.wr.send(msg) 1133 return 0 1134 1135 def recv(self, n=65535): 1136 # type: (int) -> Optional[bytes] 1137 return self.read(n) 1138 1139 def send(self, msg): 1140 # type: (bytes) -> int 1141 return self.write(msg) 1142 1143 class _IO_mixer: 1144 def __init__(self, 1145 rd, # type: ObjectPipe[Any] 1146 wr, # type: ObjectPipe[Any] 1147 ): 1148 # type: (...) -> None 1149 self.rd = rd 1150 self.wr = wr 1151 1152 def fileno(self): 1153 # type: () -> Any 1154 if isinstance(self.rd, ObjectPipe): 1155 return self.rd.fileno() 1156 return self.rd 1157 1158 def recv(self, n=None): 1159 # type: (Optional[int]) -> Any 1160 return self.rd.recv(n) 1161 1162 def read(self, n=None): 1163 # type: (Optional[int]) -> Any 1164 return self.recv(n) 1165 1166 def send(self, msg): 1167 # type: (str) -> int 1168 return self.wr.send(msg) 1169 1170 def write(self, msg): 1171 # type: (str) -> int 1172 return self.send(msg) 1173 1174 class AutomatonException(Exception): 1175 def __init__(self, msg, state=None, result=None): 1176 # type: (str, Optional[Message], Optional[str]) -> None 1177 Exception.__init__(self, msg) 1178 self.state = state 1179 self.result = result 1180 1181 class AutomatonError(AutomatonException): 1182 pass 1183 1184 class ErrorState(AutomatonException): 1185 pass 1186 1187 class Stuck(AutomatonException): 1188 pass 1189 1190 class AutomatonStopped(AutomatonException): 1191 pass 1192 1193 class Breakpoint(AutomatonStopped): 1194 pass 1195 1196 class Singlestep(AutomatonStopped): 1197 pass 1198 1199 class InterceptionPoint(AutomatonStopped): 1200 def __init__(self, msg, state=None, result=None, packet=None): 1201 # type: (str, Optional[Message], Optional[str], Optional[Packet]) -> None 1202 Automaton.AutomatonStopped.__init__(self, msg, state=state, result=result) 1203 self.packet = packet 1204 1205 class CommandMessage(AutomatonException): 1206 pass 1207 1208 # Services 1209 def debug(self, lvl, msg): 1210 # type: (int, str) -> None 1211 if self.debug_level >= lvl: 1212 log_runtime.debug(msg) 1213 1214 def isrunning(self): 1215 # type: () -> bool 1216 return self.started.locked() 1217 1218 def send(self, pkt, **kwargs): 1219 # type: (Packet, **Any) -> None 1220 if self.state.state in self.interception_points: 1221 self.debug(3, "INTERCEPT: packet intercepted: %s" % pkt.summary()) 1222 self.intercepted_packet = pkt 1223 self.cmdout.send( 1224 Message(type=_ATMT_Command.INTERCEPT, 1225 state=self.state, pkt=pkt) 1226 ) 1227 cmd = self.cmdin.recv() 1228 if not cmd: 1229 self.debug(3, "CANCELLED") 1230 return 1231 self.intercepted_packet = None 1232 if cmd.type == _ATMT_Command.REJECT: 1233 self.debug(3, "INTERCEPT: packet rejected") 1234 return 1235 elif cmd.type == _ATMT_Command.REPLACE: 1236 pkt = cmd.pkt 1237 self.debug(3, "INTERCEPT: packet replaced by: %s" % pkt.summary()) # noqa: E501 1238 elif cmd.type == _ATMT_Command.ACCEPT: 1239 self.debug(3, "INTERCEPT: packet accepted") 1240 else: 1241 raise self.AutomatonError("INTERCEPT: unknown verdict: %r" % cmd.type) # noqa: E501 1242 self.my_send(pkt, **kwargs) 1243 self.debug(3, "SENT : %s" % pkt.summary()) 1244 1245 if self.store_packets: 1246 self.packets.append(pkt.copy()) 1247 1248 def __iter__(self): 1249 # type: () -> Automaton 1250 return self 1251 1252 def __del__(self): 1253 # type: () -> None 1254 self.destroy() 1255 1256 def _run_condition(self, cond, *args, **kargs): 1257 # type: (_StateWrapper, Any, Any) -> None 1258 try: 1259 self.debug(5, "Trying %s [%s]" % (cond.atmt_type, cond.atmt_condname)) # noqa: E501 1260 cond(self, *args, **kargs) 1261 except ATMT.NewStateRequested as state_req: 1262 self.debug(2, "%s [%s] taken to state [%s]" % (cond.atmt_type, cond.atmt_condname, state_req.state)) # noqa: E501 1263 if cond.atmt_type == ATMT.RECV: 1264 if self.store_packets: 1265 self.packets.append(args[0]) 1266 for action in self.actions[cond.atmt_condname]: 1267 self.debug(2, " + Running action [%s]" % action.__name__) 1268 action(self, *state_req.action_args, **state_req.action_kargs) 1269 raise 1270 except Exception as e: 1271 self.debug(2, "%s [%s] raised exception [%s]" % (cond.atmt_type, cond.atmt_condname, e)) # noqa: E501 1272 raise 1273 else: 1274 self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname)) # noqa: E501 1275 1276 def _do_start(self, *args, **kargs): 1277 # type: (Any, Any) -> None 1278 ready = threading.Event() 1279 _t = threading.Thread( 1280 target=self._do_control, 1281 args=(ready,) + (args), 1282 kwargs=kargs, 1283 name="scapy.automaton _do_start" 1284 ) 1285 _t.daemon = True 1286 _t.start() 1287 ready.wait() 1288 1289 def _do_control(self, ready, *args, **kargs): 1290 # type: (threading.Event, Any, Any) -> None 1291 with self.started: 1292 self.threadid = threading.current_thread().ident 1293 if self.threadid is None: 1294 self.threadid = 0 1295 1296 # Update default parameters 1297 a = args + self.init_args[len(args):] 1298 k = self.init_kargs.copy() 1299 k.update(kargs) 1300 self.parse_args(*a, **k) 1301 1302 # Start the automaton 1303 self.state = self.initial_states[0](self) 1304 self.send_sock = self.sock or self.send_sock_class(**self.socket_kargs) 1305 if self.recv_conditions: 1306 # Only start a receiving socket if we have at least one recv_conditions 1307 self.listen_sock = self.sock or self.recv_sock_class(**self.socket_kargs) # noqa: E501 1308 self.packets = PacketList(name="session[%s]" % self.__class__.__name__) 1309 1310 singlestep = True 1311 iterator = self._do_iter() 1312 self.debug(3, "Starting control thread [tid=%i]" % self.threadid) 1313 # Sync threads 1314 ready.set() 1315 try: 1316 while True: 1317 c = self.cmdin.recv() 1318 if c is None: 1319 return None 1320 self.debug(5, "Received command %s" % c.type) 1321 if c.type == _ATMT_Command.RUN: 1322 singlestep = False 1323 elif c.type == _ATMT_Command.NEXT: 1324 singlestep = True 1325 elif c.type == _ATMT_Command.FREEZE: 1326 continue 1327 elif c.type == _ATMT_Command.STOP: 1328 if self.stop_state: 1329 # There is a stop state 1330 self.state = self.stop_state() 1331 iterator = self._do_iter() 1332 else: 1333 # Act as FORCESTOP 1334 break 1335 elif c.type == _ATMT_Command.FORCESTOP: 1336 break 1337 while True: 1338 state = next(iterator) 1339 if isinstance(state, self.CommandMessage): 1340 break 1341 elif isinstance(state, self.Breakpoint): 1342 c = Message(type=_ATMT_Command.BREAKPOINT, state=state) # noqa: E501 1343 self.cmdout.send(c) 1344 break 1345 if singlestep: 1346 c = Message(type=_ATMT_Command.SINGLESTEP, state=state) # noqa: E501 1347 self.cmdout.send(c) 1348 break 1349 except (StopIteration, RuntimeError): 1350 c = Message(type=_ATMT_Command.END, 1351 result=self.final_state_output) 1352 self.cmdout.send(c) 1353 except Exception as e: 1354 exc_info = sys.exc_info() 1355 self.debug(3, "Transferring exception from tid=%i:\n%s" % (self.threadid, "".join(traceback.format_exception(*exc_info)))) # noqa: E501 1356 m = Message(type=_ATMT_Command.EXCEPTION, exception=e, exc_info=exc_info) # noqa: E501 1357 self.cmdout.send(m) 1358 self.debug(3, "Stopping control thread (tid=%i)" % self.threadid) 1359 self.threadid = None 1360 if self.listen_sock: 1361 self.listen_sock.close() 1362 if self.send_sock: 1363 self.send_sock.close() 1364 1365 def _do_iter(self): 1366 # type: () -> Iterator[Union[Automaton.AutomatonException, Automaton.AutomatonStopped, ATMT.NewStateRequested, None]] # noqa: E501 1367 while True: 1368 try: 1369 self.debug(1, "## state=[%s]" % self.state.state) 1370 1371 # Entering a new state. First, call new state function 1372 if self.state.state in self.breakpoints and self.state.state != self.breakpointed: # noqa: E501 1373 self.breakpointed = self.state.state 1374 yield self.Breakpoint("breakpoint triggered on state %s" % self.state.state, # noqa: E501 1375 state=self.state.state) 1376 self.breakpointed = None 1377 state_output = self.state.run() 1378 if self.state.error: 1379 raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output), # noqa: E501 1380 result=state_output, state=self.state.state) # noqa: E501 1381 if self.state.final: 1382 self.final_state_output = state_output 1383 return 1384 1385 if state_output is None: 1386 state_output = () 1387 elif not isinstance(state_output, list): 1388 state_output = state_output, 1389 1390 timers = self.timeout[self.state.state] 1391 # If there are commandMessage, we should skip immediate 1392 # conditions. 1393 if not select_objects([self.cmdin], 0): 1394 # Then check immediate conditions 1395 for cond in self.conditions[self.state.state]: 1396 self._run_condition(cond, *state_output) 1397 1398 # If still there and no conditions left, we are stuck! 1399 if (len(self.recv_conditions[self.state.state]) == 0 and 1400 len(self.ioevents[self.state.state]) == 0 and 1401 timers.count() == 0): 1402 raise self.Stuck("stuck in [%s]" % self.state.state, 1403 state=self.state.state, 1404 result=state_output) 1405 1406 # Finally listen and pay attention to timeouts 1407 timers.reset() 1408 time_previous = time.time() 1409 1410 fds = [self.cmdin] # type: List[Union[SuperSocket, ObjectPipe[Any]]] 1411 select_func = select_objects 1412 if self.listen_sock and self.recv_conditions[self.state.state]: 1413 fds.append(self.listen_sock) 1414 select_func = self.listen_sock.select # type: ignore 1415 for ioev in self.ioevents[self.state.state]: 1416 fds.append(self.ioin[ioev.atmt_ioname]) 1417 while True: 1418 time_current = time.time() 1419 timers.decrement(time_current - time_previous) 1420 time_previous = time_current 1421 for timer in timers.expired(): 1422 self._run_condition(timer._func, *state_output) 1423 remain = timers.until_next() 1424 1425 self.debug(5, "Select on %r" % fds) 1426 r = select_func(fds, remain) 1427 self.debug(5, "Selected %r" % r) 1428 for fd in r: 1429 self.debug(5, "Looking at %r" % fd) 1430 if fd == self.cmdin: 1431 yield self.CommandMessage("Received command message") # noqa: E501 1432 elif fd == self.listen_sock: 1433 try: 1434 pkt = self.listen_sock.recv() 1435 except EOFError: 1436 # Socket was closed abruptly. This will likely only 1437 # ever happen when a client socket is passed to the 1438 # automaton (not the case when the automaton is 1439 # listening on a promiscuous conf.L2sniff) 1440 self.listen_sock.close() 1441 # False so that it is still reset by update_sock 1442 self.listen_sock = False # type: ignore 1443 fds.remove(fd) 1444 if self.state.state in self.eofs: 1445 # There is an eof state 1446 eof = self.eofs[self.state.state] 1447 self.debug(2, "Condition EOF [%s] taken" % eof.__name__) # noqa: E501 1448 raise self.eofs[self.state.state](self) 1449 else: 1450 # There isn't. Therefore, it's a closing condition. 1451 raise EOFError("Socket ended arbruptly.") 1452 if pkt is not None: 1453 if self.master_filter(pkt): 1454 self.debug(3, "RECVD: %s" % pkt.summary()) # noqa: E501 1455 for rcvcond in self.recv_conditions[self.state.state]: # noqa: E501 1456 self._run_condition(rcvcond, pkt, *state_output) # noqa: E501 1457 else: 1458 self.debug(4, "FILTR: %s" % pkt.summary()) # noqa: E501 1459 else: 1460 self.debug(3, "IOEVENT on %s" % fd.ioname) 1461 for ioevt in self.ioevents[self.state.state]: 1462 if ioevt.atmt_ioname == fd.ioname: 1463 self._run_condition(ioevt, fd, *state_output) # noqa: E501 1464 1465 except ATMT.NewStateRequested as state_req: 1466 self.debug(2, "switching from [%s] to [%s]" % (self.state.state, state_req.state)) # noqa: E501 1467 self.state = state_req 1468 yield state_req 1469 1470 def __repr__(self): 1471 # type: () -> str 1472 return "<Automaton %s [%s]>" % ( 1473 self.__class__.__name__, 1474 ["HALTED", "RUNNING"][self.isrunning()] 1475 ) 1476 1477 # Public API 1478 def add_interception_points(self, *ipts): 1479 # type: (Any) -> None 1480 for ipt in ipts: 1481 if hasattr(ipt, "atmt_state"): 1482 ipt = ipt.atmt_state 1483 self.interception_points.add(ipt) 1484 1485 def remove_interception_points(self, *ipts): 1486 # type: (Any) -> None 1487 for ipt in ipts: 1488 if hasattr(ipt, "atmt_state"): 1489 ipt = ipt.atmt_state 1490 self.interception_points.discard(ipt) 1491 1492 def add_breakpoints(self, *bps): 1493 # type: (Any) -> None 1494 for bp in bps: 1495 if hasattr(bp, "atmt_state"): 1496 bp = bp.atmt_state 1497 self.breakpoints.add(bp) 1498 1499 def remove_breakpoints(self, *bps): 1500 # type: (Any) -> None 1501 for bp in bps: 1502 if hasattr(bp, "atmt_state"): 1503 bp = bp.atmt_state 1504 self.breakpoints.discard(bp) 1505 1506 def start(self, *args, **kargs): 1507 # type: (Any, Any) -> None 1508 if self.isrunning(): 1509 raise ValueError("Already started") 1510 # Start the control thread 1511 self._do_start(*args, **kargs) 1512 1513 def run(self, 1514 resume=None, # type: Optional[Message] 1515 wait=True # type: Optional[bool] 1516 ): 1517 # type: (...) -> Any 1518 if resume is None: 1519 resume = Message(type=_ATMT_Command.RUN) 1520 self.cmdin.send(resume) 1521 if wait: 1522 try: 1523 c = self.cmdout.recv() 1524 if c is None: 1525 return None 1526 except KeyboardInterrupt: 1527 self.cmdin.send(Message(type=_ATMT_Command.FREEZE)) 1528 return None 1529 if c.type == _ATMT_Command.END: 1530 return c.result 1531 elif c.type == _ATMT_Command.INTERCEPT: 1532 raise self.InterceptionPoint("packet intercepted", state=c.state.state, packet=c.pkt) # noqa: E501 1533 elif c.type == _ATMT_Command.SINGLESTEP: 1534 raise self.Singlestep("singlestep state=[%s]" % c.state.state, state=c.state.state) # noqa: E501 1535 elif c.type == _ATMT_Command.BREAKPOINT: 1536 raise self.Breakpoint("breakpoint triggered on state [%s]" % c.state.state, state=c.state.state) # noqa: E501 1537 elif c.type == _ATMT_Command.EXCEPTION: 1538 # this code comes from the `six` module (`.reraise()`) 1539 # to raise an exception with specified exc_info. 1540 value = c.exc_info[0]() if c.exc_info[1] is None else c.exc_info[1] # type: ignore # noqa: E501 1541 if value.__traceback__ is not c.exc_info[2]: 1542 raise value.with_traceback(c.exc_info[2]) 1543 raise value 1544 return None 1545 1546 def runbg(self, resume=None, wait=False): 1547 # type: (Optional[Message], Optional[bool]) -> None 1548 self.run(resume, wait) 1549 1550 def __next__(self): 1551 # type: () -> Any 1552 return self.run(resume=Message(type=_ATMT_Command.NEXT)) 1553 1554 def _flush_inout(self): 1555 # type: () -> None 1556 # Flush command pipes 1557 for cmd in [self.cmdin, self.cmdout]: 1558 cmd.clear() 1559 1560 def destroy(self): 1561 # type: () -> None 1562 """ 1563 Destroys a stopped Automaton: this cleanups all opened file descriptors. 1564 Required on PyPy for instance where the garbage collector behaves differently. 1565 """ 1566 if not hasattr(self, "started"): 1567 return # was never started. 1568 if self.isrunning(): 1569 raise ValueError("Can't close running Automaton ! Call stop() beforehand") 1570 # Close command pipes 1571 self.cmdin.close() 1572 self.cmdout.close() 1573 self._flush_inout() 1574 # Close opened ioins/ioouts 1575 for i in itertools.chain(self.ioin.values(), self.ioout.values()): 1576 if isinstance(i, ObjectPipe): 1577 i.close() 1578 1579 def stop(self, wait=True): 1580 # type: (bool) -> None 1581 try: 1582 self.cmdin.send(Message(type=_ATMT_Command.STOP)) 1583 except OSError: 1584 pass 1585 if wait: 1586 with self.started: 1587 self._flush_inout() 1588 1589 def forcestop(self, wait=True): 1590 # type: (bool) -> None 1591 try: 1592 self.cmdin.send(Message(type=_ATMT_Command.FORCESTOP)) 1593 except OSError: 1594 pass 1595 if wait: 1596 with self.started: 1597 self._flush_inout() 1598 1599 def restart(self, *args, **kargs): 1600 # type: (Any, Any) -> None 1601 self.stop() 1602 self.start(*args, **kargs) 1603 1604 def accept_packet(self, 1605 pkt=None, # type: Optional[Packet] 1606 wait=False # type: Optional[bool] 1607 ): 1608 # type: (...) -> Any 1609 rsm = Message() 1610 if pkt is None: 1611 rsm.type = _ATMT_Command.ACCEPT 1612 else: 1613 rsm.type = _ATMT_Command.REPLACE 1614 rsm.pkt = pkt 1615 return self.run(resume=rsm, wait=wait) 1616 1617 def reject_packet(self, 1618 wait=False # type: Optional[bool] 1619 ): 1620 # type: (...) -> Any 1621 rsm = Message(type=_ATMT_Command.REJECT) 1622 return self.run(resume=rsm, wait=wait) 1623