• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# SPDX-License-Identifier: GPL-2.0-only
2# This file is part of Scapy
3# See https://scapy.net/ for more information
4# Copyright (C) Philippe Biondi <phil@secdev.org>
5# Copyright (C) Gabriel Potter
6
7"""
8Automata with states, transitions and actions.
9
10TODO:
11    - add documentation for ioevent, as_supersocket...
12"""
13
14import ctypes
15import itertools
16import logging
17import os
18import random
19import socket
20import sys
21import threading
22import time
23import traceback
24import types
25
26import select
27from collections import deque
28
29from scapy.config import conf
30from scapy.consts import WINDOWS
31from scapy.data import MTU
32from scapy.error import log_runtime, warning
33from scapy.interfaces import _GlobInterfaceType
34from scapy.packet import Packet
35from scapy.plist import PacketList
36from scapy.supersocket import SuperSocket, StreamSocket
37from scapy.utils import do_graph
38
39# Typing imports
40from typing import (
41    Any,
42    Callable,
43    Deque,
44    Dict,
45    Generic,
46    Iterable,
47    Iterator,
48    List,
49    Optional,
50    Set,
51    Tuple,
52    Type,
53    TypeVar,
54    Union,
55    cast,
56)
57from scapy.compat import DecoratorCallable
58
59
# winsock.h: FD_READ network-event bit. select_objects() passes it to
# WSAEventSelect() to request read-readiness notifications for a socket.
FD_READ = 0x00000001
62
63
def select_objects(inputs, remain):
    # type: (Iterable[Any], Union[float, int, None]) -> List[Any]
    """
    Select objects. Same as:
    ``select.select(inputs, [], [], remain)``

    But also works on Windows, only on objects whose fileno() returns
    a Windows event. For simplicity, just use `ObjectPipe()` as a queue
    that you can select on whatever the platform is.

    If you want an object to be always included in the output of
    select_objects (i.e. it's not selectable), just make fileno()
    return a strictly negative value.

    Example:

        >>> a, b = ObjectPipe("a"), ObjectPipe("b")
        >>> b.send("test")
        1
        >>> select_objects([a, b], 1) == [b]
        True

    :param inputs: objects to process
    :param remain: timeout in seconds. If 0, poll. If None, block.
        On Windows, a negative timeout is treated as a poll.
    :returns: the subset of ``inputs`` that is ready to be read
    """
    if not WINDOWS:
        return select.select(inputs, [], [], remain)[0]
    inputs = list(inputs)
    kernel32 = ctypes.windll.kernel32
    ws2_32 = ctypes.windll.ws2_32
    # handles[k] is the waitable handle of owners[k]. Both lists are filled
    # in lockstep because objects with a negative fileno() get no handle:
    # an index into ``handles`` is therefore NOT an index into ``inputs``.
    handles = []  # type: List[int]
    owners = []  # type: List[Any]
    created = []  # type: List[int]
    results = set()
    try:
        for obj in inputs:
            if getattr(obj, "__selectable_force_select__", False):
                # Native socket.socket object. We would normally use
                # select.select: bind a WSA event to its FD_READ instead.
                # NOTE(review): WSAEventSelect also switches the socket to
                # non-blocking mode and the association is not cancelled
                # below - confirm callers tolerate this.
                evt = ws2_32.WSACreateEvent()
                created.append(evt)
                res = ws2_32.WSAEventSelect(
                    ctypes.c_void_p(obj.fileno()),
                    evt,
                    FD_READ
                )
                # res == 0: it was a socket, wait on the WSA event.
                # Otherwise fall back to using fileno() as an event handle.
                handles.append(evt if res == 0 else obj.fileno())
                owners.append(obj)
            elif obj.fileno() < 0:
                # Special case: On Windows, we consider that an object that
                # returns a negative fileno (impossible), is always readable.
                # This is used in very few places but important (e.g.
                # PcapReader), where we have no valid fileno (and will stop
                # on EOFError). Don't block the other objects in that case.
                results.add(obj)
                remain = 0
            else:
                handles.append(obj.fileno())
                owners.append(obj)
        if handles:
            if remain is None:
                remainms = 0xFFFFFFFF  # INFINITE
            else:
                # A negative DWORD timeout would wrap around to (nearly)
                # INFINITE: clamp it to a simple poll.
                remainms = max(0, int(remain * 1000))
            if len(handles) == 1:
                res = kernel32.WaitForSingleObject(
                    ctypes.c_void_p(handles[0]),
                    remainms
                )
            else:
                # Sadly, the only way to emulate select() is to first check
                # if any object is available using WaitForMultipleObjects
                # then poll the others.
                res = kernel32.WaitForMultipleObjects(
                    len(handles),
                    (ctypes.c_void_p * len(handles))(*handles),
                    False,
                    remainms
                )
            # ctypes returns the DWORD as a signed C int (WAIT_FAILED comes
            # back as -1), so normalize it first. WAIT_OBJECT_0 + k means
            # handles[k] is signaled; WAIT_TIMEOUT (0x102), WAIT_FAILED and
            # WAIT_ABANDONED_0 + k (0x80 + k) are all >= len(handles), since
            # at most 64 handles can be waited on at once.
            res &= 0xFFFFFFFF
            if res < len(handles):
                results.add(owners[res])
                if len(handles) > 1:
                    # Now poll the others, if any
                    for owner, handle in zip(owners, handles):
                        polled = kernel32.WaitForSingleObject(
                            ctypes.c_void_p(handle),
                            0  # poll: don't wait
                        )
                        if polled == 0:
                            results.add(owner)
    finally:
        # Cleanup created events, if any (also when something raised)
        for evt in created:
            ws2_32.WSACloseEvent(evt)
    return list(results)
154
155
# Type of the objects queued in an ObjectPipe.
_T = TypeVar("_T")
157
158
class ObjectPipe(Generic[_T]):
    """
    FIFO queue of Python objects that can be waited on like a file.

    Objects are kept in a deque. Each send() also writes one byte to an OS
    pipe (and, on Windows, sets a manual-reset event), so fileno() becomes
    ready while objects are queued and select_objects() can wait on it.
    """

    def __init__(self, name=None):
        # type: (Optional[str]) -> None
        self.name = name or "ObjectPipe"
        # Reported as closed until the pipe exists, so that __del__/close()
        # have nothing to release if os.pipe() fails (e.g. EMFILE).
        self.closed = True
        self.__rd, self.__wr = os.pipe()
        self.__queue = deque()  # type: Deque[_T]
        self.closed = False
        if WINDOWS:
            self._wincreate()

    if WINDOWS:
        def _wincreate(self):
            # type: () -> None
            # Manual-reset, initially non-signaled, *anonymous* event.
            # A named event would be shared with any other ObjectPipe (in
            # this or another process of the session) that happened to
            # draw the same name, mixing up their readiness.
            self._fd = cast(int, ctypes.windll.kernel32.CreateEventA(
                None, True, False, None
            ))

        def _winset(self):
            # type: () -> None
            if ctypes.windll.kernel32.SetEvent(ctypes.c_void_p(self._fd)) == 0:
                warning(ctypes.FormatError(ctypes.GetLastError()))

        def _winreset(self):
            # type: () -> None
            if ctypes.windll.kernel32.ResetEvent(ctypes.c_void_p(self._fd)) == 0:
                warning(ctypes.FormatError(ctypes.GetLastError()))

        def _winclose(self):
            # type: () -> None
            if ctypes.windll.kernel32.CloseHandle(ctypes.c_void_p(self._fd)) == 0:
                warning(ctypes.FormatError(ctypes.GetLastError()))

    def fileno(self):
        # type: () -> int
        """Return the event handle on Windows, the pipe's read end elsewhere."""
        if WINDOWS:
            return self._fd
        return self.__rd

    def send(self, obj):
        # type: (_T) -> int
        """Queue ``obj`` and signal readiness. Always returns 1."""
        # Queue before signaling, so a woken reader always finds the object.
        self.__queue.append(obj)
        if WINDOWS:
            self._winset()
        os.write(self.__wr, b"X")
        return 1

    def write(self, obj):
        # type: (_T) -> None
        """File-like alias of send()."""
        self.send(obj)

    def empty(self):
        # type: () -> bool
        """Return True when no object is queued."""
        return not bool(self.__queue)

    def flush(self):
        # type: () -> None
        """No-op, for file-like compatibility."""
        pass

    def recv(self, n=0, options=socket.MsgFlag(0)):
        # type: (Optional[int], socket.MsgFlag) -> Optional[_T]
        """
        Pop and return the oldest queued object.

        :param n: unused, accepted for socket-like compatibility
        :param options: only ``socket.MSG_PEEK`` is honoured: return the
            oldest object (None when empty) without removing it
        :raises EOFError: if the pipe is closed
        """
        if self.closed:
            raise EOFError
        if options & socket.MSG_PEEK:
            if self.__queue:
                return self.__queue[0]
            return None
        # Consume the byte matching one object; blocks until send() ran.
        os.read(self.__rd, 1)
        elt = self.__queue.popleft()
        if WINDOWS and not self.__queue:
            self._winreset()
        return elt

    def read(self, n=0):
        # type: (Optional[int]) -> Optional[_T]
        """File-like alias of recv()."""
        return self.recv(n)

    def clear(self):
        # type: () -> None
        """Drop every queued object (no-op once closed)."""
        if not self.closed:
            while not self.empty():
                self.recv()

    def close(self):
        # type: () -> None
        """Release the pipe (and Windows event). Safe to call twice."""
        if not self.closed:
            self.closed = True
            os.close(self.__rd)
            os.close(self.__wr)
            if WINDOWS:
                try:
                    self._winclose()
                except ImportError:
                    # Python is shutting down
                    pass

    def __repr__(self):
        # type: () -> str
        return "<%s at %s>" % (self.name, id(self))

    def __del__(self):
        # type: () -> None
        self.close()

    @staticmethod
    def select(sockets, remain=conf.recv_poll_rate):
        # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
        """
        Return the closed pipes immediately (so that reading them raises
        EOFError), otherwise wait on all of them with select_objects().
        """
        # Only handle ObjectPipes
        results = []
        for s in sockets:
            if s.closed:  # allow read to trigger EOF
                results.append(s)
        if results:
            return results
        return select_objects(sockets, remain)
274
275
class Message:
    """
    Plain attribute bag: every keyword passed to the constructor becomes an
    instance attribute. The class-level defaults below are None.
    """
    type = None        # type: str
    pkt = None         # type: Packet
    result = None      # type: str
    state = None       # type: Message
    exc_info = None    # type: Union[Tuple[None, None, None], Tuple[BaseException, Exception, types.TracebackType]] # noqa: E501

    def __init__(self, **args):
        # type: (Any) -> None
        vars(self).update(args)

    def __repr__(self):
        # type: () -> str
        # Show public attributes only, in the order they were set.
        shown = []
        for key, value in vars(self).items():
            if key.startswith("_"):
                continue
            shown.append("%s=%r" % (key, value))
        return "<Message %s>" % " ".join(shown)
288        # type: () -> str
289        return "<Message %s>" % " ".join(
290            "%s=%r" % (k, v)
291            for k, v in self.__dict__.items()
292            if not k.startswith("_")
293        )
294
295
class Timer():
    """
    Countdown attached to an automaton state.

    Created by ``ATMT.timeout`` (one-shot) and by ``ATMT.timer``
    (``autoreload=True``: re-arms itself when it expires). Time only passes
    through explicit ``_decrement()`` calls. ``_func`` is the decorated
    condition function once a decorator has attached it.

    Timers order by remaining time, then by priority. Since ``__eq__`` is
    defined without ``__hash__``, instances are unhashable.
    """
    def __init__(self, time, prio=0, autoreload=False):
        # type: (Union[int, float], int, bool) -> None
        self._timeout = float(time)  # type: float
        self._time = 0  # type: float
        self._just_expired = True
        self._expired = True
        self._prio = prio
        # Placeholder until ATMT.timeout / ATMT.timer assign the function
        self._func = _StateWrapper()
        self._autoreload = autoreload

    def get(self):
        # type: () -> float
        """Return the configured duration."""
        return self._timeout

    def set(self, val):
        # type: (float) -> None
        """Change the duration used by the next (re)arm."""
        self._timeout = val

    def _reset(self):
        # type: () -> None
        """(Re)arm the timer for a full duration."""
        self._time = self._timeout
        self._expired = False
        self._just_expired = False

    def _reset_just_expired(self):
        # type: () -> None
        self._just_expired = False

    def _running(self):
        # type: () -> bool
        return self._time > 0

    def _remaining(self):
        # type: () -> float
        return max(self._time, 0)

    def _decrement(self, time):
        # type: (float) -> None
        """
        Let ``time`` seconds elapse. On expiry, flag ``_just_expired`` and
        either re-arm (autoreload) or stop at 0.
        """
        self._time -= time
        if self._time <= 0:
            if not self._expired:
                self._just_expired = True
                if self._autoreload:
                    # take overshoot into account
                    self._time = self._timeout + self._time
                else:
                    self._expired = True
                    self._time = 0

    def __lt__(self, obj):
        # type: (Timer) -> bool
        if not isinstance(obj, Timer):
            return NotImplemented
        if self._time != obj._time:
            return self._time < obj._time
        return self._prio < obj._prio

    def __gt__(self, obj):
        # type: (Timer) -> bool
        if not isinstance(obj, Timer):
            return NotImplemented
        if self._time != obj._time:
            return self._time > obj._time
        return self._prio > obj._prio

    def __eq__(self, obj):
        # type: (Any) -> bool
        if not isinstance(obj, Timer):
            # Let Python try the reflected operation / fall back to identity
            # instead of crashing on e.g. ``timer == None``.
            return NotImplemented
        return (self._time == obj._time) and (self._prio == obj._prio)

    def __repr__(self):
        # type: () -> str
        return "<Timer %f(%f)>" % (self._time, self._timeout)
365
366
367class _TimerList():
368    def __init__(self):
369        # type: () -> None
370        self.timers = []  # type: list[Timer]
371
372    def add_timer(self, timer):
373        # type: (Timer) -> None
374        self.timers.append(timer)
375
376    def reset(self):
377        # type: () -> None
378        for t in self.timers:
379            t._reset()
380
381    def decrement(self, time):
382        # type: (float) -> None
383        for t in self.timers:
384            t._decrement(time)
385
386    def expired(self):
387        # type: () -> list[Timer]
388        lst = [t for t in self.timers if t._just_expired]
389        lst.sort(key=lambda x: x._prio, reverse=True)
390        for t in lst:
391            t._reset_just_expired()
392        return lst
393
394    def until_next(self):
395        # type: () -> Optional[float]
396        try:
397            return min([t._remaining() for t in self.timers if t._running()])
398        except ValueError:
399            return None  # None means blocking
400
401    def count(self):
402        # type: () -> int
403        return len(self.timers)
404
405    def __iter__(self):
406        # type: () -> Iterator[Timer]
407        return self.timers.__iter__()
408
409    def __repr__(self):
410        # type: () -> str
411        return self.timers.__repr__()
412
413
414class _instance_state:
415    def __init__(self, instance):
416        # type: (Any) -> None
417        self.__self__ = instance.__self__
418        self.__func__ = instance.__func__
419        self.__self__.__class__ = instance.__self__.__class__
420
421    def __getattr__(self, attr):
422        # type: (str) -> Any
423        return getattr(self.__func__, attr)
424
425    def __call__(self, *args, **kargs):
426        # type: (Any, Any) -> Any
427        return self.__func__(self.__self__, *args, **kargs)
428
429    def breaks(self):
430        # type: () -> Any
431        return self.__self__.add_breakpoints(self.__func__)
432
433    def intercepts(self):
434        # type: () -> Any
435        return self.__self__.add_interception_points(self.__func__)
436
437    def unbreaks(self):
438        # type: () -> Any
439        return self.__self__.remove_breakpoints(self.__func__)
440
441    def unintercepts(self):
442        # type: () -> Any
443        return self.__self__.remove_interception_points(self.__func__)
444
445
446##############
447#  Automata  #
448##############
449
class _StateWrapper:
    """
    Typing helper describing the ``atmt_*`` attributes that the ATMT
    decorators attach to automaton functions.

    It is mostly used in ``cast()`` calls and type comments. ``Timer``
    also instantiates it as a placeholder for ``Timer._func`` until
    ``ATMT.timeout`` / ``ATMT.timer`` assign the decorated function.
    """
    __name__ = None             # type: str
    # Kind of decorated member: one of the ATMT.* kind strings
    atmt_type = None            # type: str
    # Name of the state this member is (ATMT.state) or belongs to
    atmt_state = None           # type: str
    # State flags, set by ATMT.state
    atmt_initial = None         # type: int
    atmt_final = None           # type: int
    atmt_stop = None            # type: int
    atmt_error = None           # type: int
    # Undecorated state function, kept by ATMT.state's wrapper
    atmt_origfunc = None        # type: _StateWrapper
    # Priority of condition / receive_condition / ioevent members
    atmt_prio = None            # type: int
    # ATMT.ioevent only: optional SuperSocket attribute name, and I/O name
    atmt_as_supersocket = None  # type: Optional[str]
    atmt_condname = None        # type: str
    atmt_ioname = None          # type: str
    # ATMT.timeout / ATMT.timer only
    atmt_timeout = None         # type: Timer
    # ATMT.action only: {condition name: action priority}
    atmt_cond = None            # type: Dict[str, int]
    __code__ = None             # type: types.CodeType
    __call__ = None             # type: Callable[..., ATMT.NewStateRequested]
467
468
class ATMT:
    """
    Namespace of the decorators used to declare an Automaton.

    ``state`` marks a method as a state. ``condition``,
    ``receive_condition``, ``ioevent``, ``timeout``, ``timer`` and ``eof``
    attach a transition check to a state. ``action`` attaches a method to one
    or more conditions. The decorators only record ``atmt_*`` attributes on
    the function; Automaton_metaclass indexes them when the class is built.
    """
    STATE = "State"
    ACTION = "Action"
    CONDITION = "Condition"
    RECV = "Receive condition"
    TIMEOUT = "Timeout condition"
    EOF = "EOF condition"
    IOEVENT = "I/O event"

    class NewStateRequested(Exception):
        """
        Request to switch the automaton to another state.

        Returned by a method decorated with ``ATMT.state`` (and, being an
        Exception, it can also be raised). It captures the target state
        function, the automaton and the call arguments; ``run()`` executes
        the state function with them.
        """
        def __init__(self, state_func, automaton, *args, **kargs):
            # type: (Any, ATMT, Any, Any) -> None
            self.func = state_func
            self.state = state_func.atmt_state
            self.initial = state_func.atmt_initial
            self.error = state_func.atmt_error
            self.stop = state_func.atmt_stop
            self.final = state_func.atmt_final
            Exception.__init__(self, "Request state [%s]" % self.state)
            self.automaton = automaton
            # NOTE: this replaces Exception.args with the state's call
            # arguments, so str(self) does not show the message above.
            self.args = args
            self.kargs = kargs
            self.action_parameters()  # init action parameters

        def action_parameters(self, *args, **kargs):
            # type: (Any, Any) -> ATMT.NewStateRequested
            """Store the arguments for the actions; returns self."""
            self.action_args = args
            self.action_kargs = kargs
            return self

        def run(self):
            # type: () -> Any
            """Call the target state function with the stored arguments."""
            return self.func(self.automaton, *self.args, **self.kargs)

        def __repr__(self):
            # type: () -> str
            return "NewStateRequested(%s)" % self.state

    @staticmethod
    def state(initial=0,    # type: int
              final=0,      # type: int
              stop=0,       # type: int
              error=0       # type: int
              ):
        # type: (...) -> Callable[[DecoratorCallable], DecoratorCallable]
        """
        Declare a state.

        :param initial: truthy for an initial state
        :param final: truthy for a final state
        :param stop: truthy for the (single) stop state
        :param error: truthy for an error state

        The method is replaced by a wrapper returning a NewStateRequested
        for the original function, which stays available as
        ``atmt_origfunc``.
        """
        def deco(f, initial=initial, final=final):
            # type: (_StateWrapper, int, int) -> _StateWrapper
            f.atmt_type = ATMT.STATE
            f.atmt_state = f.__name__
            f.atmt_initial = initial
            f.atmt_final = final
            f.atmt_stop = stop
            f.atmt_error = error

            def _state_wrapper(self, *args, **kargs):
                # type: (ATMT, Any, Any) -> ATMT.NewStateRequested
                return ATMT.NewStateRequested(f, self, *args, **kargs)

            state_wrapper = cast(_StateWrapper, _state_wrapper)
            state_wrapper.__name__ = "%s_wrapper" % f.__name__
            state_wrapper.atmt_type = ATMT.STATE
            state_wrapper.atmt_state = f.__name__
            state_wrapper.atmt_initial = initial
            state_wrapper.atmt_final = final
            state_wrapper.atmt_stop = stop
            state_wrapper.atmt_error = error
            state_wrapper.atmt_origfunc = f
            return state_wrapper
        return deco  # type: ignore

    @staticmethod
    def action(cond, prio=0):
        # type: (Any, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper]  # noqa: E501
        """
        Attach the decorated method to condition ``cond`` with priority
        ``prio``. Can be stacked to attach one method to several conditions.
        """
        def deco(f, cond=cond):
            # type: (_StateWrapper, _StateWrapper) -> _StateWrapper
            # Create the condition map unless a previous ATMT.action already
            # did (stacked decorators). Testing atmt_cond itself (rather than
            # atmt_type) avoids an AttributeError when f already carries an
            # unrelated atmt_type.
            if getattr(f, "atmt_cond", None) is None:
                f.atmt_cond = {}
            f.atmt_type = ATMT.ACTION
            f.atmt_cond[cond.atmt_condname] = prio
            return f
        return deco

    @staticmethod
    def condition(state, prio=0):
        # type: (Any, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper]  # noqa: E501
        """Declare a condition checked for ``state``."""
        def deco(f, state=state):
            # type: (_StateWrapper, _StateWrapper) -> Any
            f.atmt_type = ATMT.CONDITION
            f.atmt_state = state.atmt_state
            f.atmt_condname = f.__name__
            f.atmt_prio = prio
            return f
        return deco

    @staticmethod
    def receive_condition(state, prio=0):
        # type: (_StateWrapper, int) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper]  # noqa: E501
        """Declare a condition on received packets for ``state``."""
        def deco(f, state=state):
            # type: (_StateWrapper, _StateWrapper) -> _StateWrapper
            f.atmt_type = ATMT.RECV
            f.atmt_state = state.atmt_state
            f.atmt_condname = f.__name__
            f.atmt_prio = prio
            return f
        return deco

    @staticmethod
    def ioevent(state,                  # type: _StateWrapper
                name,                   # type: str
                prio=0,                 # type: int
                as_supersocket=None     # type: Optional[str]
                ):
        # type: (...) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper]  # noqa: E501
        """
        Declare a condition on the I/O channel ``name`` for ``state``.
        With ``as_supersocket``, the metaclass also exposes a SuperSocket
        factory under that attribute name.
        """
        def deco(f, state=state):
            # type: (_StateWrapper, _StateWrapper) -> _StateWrapper
            f.atmt_type = ATMT.IOEVENT
            f.atmt_state = state.atmt_state
            f.atmt_condname = f.__name__
            f.atmt_ioname = name
            f.atmt_prio = prio
            f.atmt_as_supersocket = as_supersocket
            return f
        return deco

    @staticmethod
    def timeout(state, timeout):
        # type: (_StateWrapper, Union[int, float]) -> Callable[[_StateWrapper, _StateWrapper, Timer], _StateWrapper]  # noqa: E501
        """Declare a one-shot timeout of ``timeout`` seconds for ``state``."""
        # The Timer default is created once per ATMT.timeout(...) call.
        def deco(f, state=state, timeout=Timer(timeout)):
            # type: (_StateWrapper, _StateWrapper, Timer) -> _StateWrapper
            f.atmt_type = ATMT.TIMEOUT
            f.atmt_state = state.atmt_state
            f.atmt_timeout = timeout
            f.atmt_timeout._func = f
            f.atmt_condname = f.__name__
            return f
        return deco

    @staticmethod
    def timer(state, timeout, prio=0):
        # type: (_StateWrapper, Union[float, int], int) -> Callable[[_StateWrapper, _StateWrapper, Timer], _StateWrapper]  # noqa: E501
        """Declare a periodic (auto-reloading) timer for ``state``."""
        def deco(f, state=state, timeout=Timer(timeout, prio=prio, autoreload=True)):
            # type: (_StateWrapper, _StateWrapper, Timer) -> _StateWrapper
            f.atmt_type = ATMT.TIMEOUT
            f.atmt_state = state.atmt_state
            f.atmt_timeout = timeout
            f.atmt_timeout._func = f
            f.atmt_condname = f.__name__
            return f
        return deco

    @staticmethod
    def eof(state):
        # type: (_StateWrapper) -> Callable[[_StateWrapper, _StateWrapper], _StateWrapper]  # noqa: E501
        """Declare the end-of-file condition of ``state``."""
        def deco(f, state=state):
            # type: (_StateWrapper, _StateWrapper) -> _StateWrapper
            f.atmt_type = ATMT.EOF
            f.atmt_state = state.atmt_state
            f.atmt_condname = f.__name__
            return f
        return deco
629
630
class _ATMT_Command:
    """
    Namespace of command/status strings.

    NOTE(review): these constants are not referenced in this part of the
    file. They are presumably exchanged between the Automaton control API
    and its run loop: RUN/NEXT/FREEZE/STOP/FORCESTOP/SINGLESTEP as requests,
    END/EXCEPTION/BREAKPOINT/INTERCEPT as notifications, and
    ACCEPT/REPLACE/REJECT as answers to an interception. Confirm this
    against the Automaton class.
    """
    RUN = "RUN"
    NEXT = "NEXT"
    FREEZE = "FREEZE"
    STOP = "STOP"
    FORCESTOP = "FORCESTOP"
    END = "END"
    EXCEPTION = "EXCEPTION"
    SINGLESTEP = "SINGLESTEP"
    BREAKPOINT = "BREAKPOINT"
    INTERCEPT = "INTERCEPT"
    ACCEPT = "ACCEPT"
    REPLACE = "REPLACE"
    REJECT = "REJECT"
645
646
class _ATMT_supersocket(SuperSocket):
    """
    SuperSocket front-end for an automaton ioevent declared with
    ``as_supersocket``.

    It instantiates the automaton with ``external_fd`` mapping ``ioevent``
    to a pair of ObjectPipes, then calls its ``runbg()``. send() writes to
    ``spa``; recv() reads from ``spb`` and passes the result through
    ``proto`` when one is given.
    """
    def __init__(self,
                 name,          # type: str
                 ioevent,       # type: str
                 automaton,     # type: Type[Automaton]
                 proto,         # type: Callable[[bytes], Any]
                 *args,         # type: Any
                 **kargs        # type: Any
                 ):
        # type: (...) -> None
        self.name = name
        self.ioevent = ioevent
        self.proto = proto
        # write, read
        self.spa, self.spb = ObjectPipe[Any]("spa"), \
            ObjectPipe[Any]("spb")
        kargs["external_fd"] = {ioevent: (self.spa, self.spb)}
        kargs["is_atmt_socket"] = True
        kargs["atmt_socket"] = self.name
        self.atmt = automaton(*args, **kargs)
        self.atmt.runbg()

    def send(self, s):
        # type: (Any) -> int
        return self.spa.send(s)

    def fileno(self):
        # type: () -> int
        return self.spb.fileno()

    # note: _ATMT_supersocket may return bytes in certain cases, which
    # is expected. We cheat on typing.
    def recv(self, n=MTU, **kwargs):  # type: ignore
        # type: (int, **Any) -> Any
        r = self.spb.recv(n)
        if self.proto is not None and r is not None:
            r = self.proto(r, **kwargs)
        return r

    def close(self):
        # type: () -> None
        """
        Stop and destroy the automaton, then close both pipes.

        The pipes are released and the socket is marked closed even if
        stopping or destroying the automaton raises. That exception still
        propagates to the caller.
        """
        # NOTE(review): relies on SuperSocket providing a ``closed``
        # attribute, since SuperSocket.__init__ is not called here.
        if not self.closed:
            try:
                try:
                    self.atmt.stop()
                finally:
                    self.atmt.destroy()
            finally:
                self.spa.close()
                self.spb.close()
                self.closed = True

    @staticmethod
    def select(sockets, remain=conf.recv_poll_rate):
        # type: (List[SuperSocket], Optional[float]) -> List[SuperSocket]
        return select_objects(sockets, remain)
699
700
701class _ATMT_to_supersocket:
702    def __init__(self, name, ioevent, automaton):
703        # type: (str, str, Type[Automaton]) -> None
704        self.name = name
705        self.ioevent = ioevent
706        self.automaton = automaton
707
708    def __call__(self, proto, *args, **kargs):
709        # type: (Callable[[bytes], Any], Any, Any) -> _ATMT_supersocket
710        return _ATMT_supersocket(
711            self.name, self.ioevent, self.automaton,
712            proto, *args, **kargs
713        )
714
715
class Automaton_metaclass(type):
    """
    Metaclass indexing the ATMT-decorated members of an Automaton class.

    At class creation it gathers the members (carrying an ``atmt_type``
    attribute) of the class and of its bases. It fills the per-class
    registries ``states``, ``conditions``, ``recv_conditions``, ``ioevents``,
    ``timeout``, ``eofs`` and ``actions``. It also exposes
    ``build_graph()`` / ``graph()`` to render the automaton in graphviz.
    """
    def __new__(cls, name, bases, dct):
        # type: (str, Tuple[Any], Dict[str, Any]) -> Type[Automaton]
        cls = super(Automaton_metaclass, cls).__new__(  # type: ignore
            cls, name, bases, dct
        )
        # Fresh registries for every class, keyed by state name (condition
        # name for ``actions``).
        cls.states = {}
        cls.recv_conditions = {}    # type: Dict[str, List[_StateWrapper]]
        cls.conditions = {}         # type: Dict[str, List[_StateWrapper]]
        cls.ioevents = {}           # type: Dict[str, List[_StateWrapper]]
        cls.timeout = {}            # type: Dict[str, _TimerList]
        cls.eofs = {}               # type: Dict[str, _StateWrapper]
        cls.actions = {}            # type: Dict[str, List[_StateWrapper]]
        cls.initial_states = []     # type: List[_StateWrapper]
        cls.stop_state = None       # type: Optional[_StateWrapper]
        cls.ionames = []
        cls.iosupersockets = []

        # Collect members of the class and its bases (breadth-first); the
        # first definition found wins, so subclasses override their bases.
        members = {}
        classes = [cls]
        while classes:
            c = classes.pop(0)  # order is important to avoid breaking method overloading  # noqa: E501
            classes += list(c.__bases__)
            for k, v in c.__dict__.items():  # type: ignore
                if k not in members:
                    members[k] = v

        decorated = [v for v in members.values()
                     if hasattr(v, "atmt_type")]

        # First pass: register states and condition names, so the second
        # pass can attach conditions/actions to them.
        for m in decorated:
            if m.atmt_type == ATMT.STATE:
                s = m.atmt_state
                cls.states[s] = m
                cls.recv_conditions[s] = []
                cls.ioevents[s] = []
                cls.conditions[s] = []
                cls.timeout[s] = _TimerList()
                if m.atmt_initial:
                    cls.initial_states.append(m)
                if m.atmt_stop:
                    if cls.stop_state is not None:
                        raise ValueError("There can only be a single stop state !")
                    cls.stop_state = m
            elif m.atmt_type in [ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT, ATMT.IOEVENT, ATMT.EOF]:  # noqa: E501
                cls.actions[m.atmt_condname] = []

        for m in decorated:
            if m.atmt_type == ATMT.CONDITION:
                cls.conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.RECV:
                cls.recv_conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.EOF:
                cls.eofs[m.atmt_state] = m
            elif m.atmt_type == ATMT.IOEVENT:
                cls.ioevents[m.atmt_state].append(m)
                cls.ionames.append(m.atmt_ioname)
                if m.atmt_as_supersocket is not None:
                    cls.iosupersockets.append(m)
            elif m.atmt_type == ATMT.TIMEOUT:
                cls.timeout[m.atmt_state].add_timer(m.atmt_timeout)
            elif m.atmt_type == ATMT.ACTION:
                for co in m.atmt_cond:
                    cls.actions[co].append(m)

        for v in itertools.chain(
            cls.conditions.values(),
            cls.recv_conditions.values(),
            cls.ioevents.values()
        ):
            v.sort(key=lambda x: x.atmt_prio)
        for condname, actlst in cls.actions.items():
            actlst.sort(key=lambda x: x.atmt_cond[condname])

        for ioev in cls.iosupersockets:
            setattr(cls, ioev.atmt_as_supersocket,
                    _ATMT_to_supersocket(
                        ioev.atmt_as_supersocket,
                        ioev.atmt_ioname,
                        cast(Type["Automaton"], cls)))

        # Inject signature
        try:
            import inspect
            cls.__signature__ = inspect.signature(cls.parse_args)  # type: ignore  # noqa: E501
        except (ImportError, AttributeError):
            pass

        return cast(Type["Automaton"], cls)

    def build_graph(self):
        # type: () -> str
        """
        Return a graphviz "dot" description of this automaton class.

        Transitions are found statically: the names and constants used by
        each state/condition function are matched against state names.
        Names that are other members of this class (helper methods, incl.
        static/class methods) are followed, so transitions made through a
        helper appear too.
        """
        def referenced(func):
            # type: (Any) -> Iterator[Any]
            # Yield every name/constant used by ``func``, expanding helper
            # members of the class in turn. Each helper is expanded at most
            # once per function, so recursive helpers cannot loop forever.
            pending = list(func.__code__.co_names + func.__code__.co_consts)
            expanded = set()  # type: Set[Any]
            while pending:
                n = pending.pop()
                yield n
                if n in self.states or n in expanded or n not in self.__dict__:
                    continue
                member = self.__dict__[n]
                # Unwrap staticmethod/classmethod objects to their function
                code = getattr(getattr(member, "__func__", member),
                               "__code__", None)
                if code is None:
                    continue
                expanded.add(n)
                pending.extend(code.co_names)
                pending.extend(code.co_consts)

        # ``self`` is the automaton class itself here.
        s = 'digraph "%s" {\n' % self.__name__

        se = ""  # Keep initial nodes at the beginning for better rendering
        for st in self.states.values():
            if st.atmt_initial:
                se = ('\t"%s" [ style=filled, fillcolor=blue, shape=box, root=true];\n' % st.atmt_state) + se  # noqa: E501
            elif st.atmt_final:
                se += '\t"%s" [ style=filled, fillcolor=green, shape=octagon ];\n' % st.atmt_state  # noqa: E501
            elif st.atmt_error:
                se += '\t"%s" [ style=filled, fillcolor=red, shape=octagon ];\n' % st.atmt_state  # noqa: E501
            elif st.atmt_stop:
                se += '\t"%s" [ style=filled, fillcolor=orange, shape=box, root=true ];\n' % st.atmt_state  # noqa: E501
        s += se

        # Direct transitions performed inside the state functions
        for st in self.states.values():
            for n in referenced(st.atmt_origfunc):
                if n in self.states:
                    s += '\t"%s" -> "%s" [ color=green ];\n' % (st.atmt_state, n)

        # Transitions performed by conditions, labelled with their actions
        for c, sty, k, v in (
            [("purple", "solid", k, v) for k, v in self.conditions.items()] +
            [("red", "solid", k, v) for k, v in self.recv_conditions.items()] +
            [("orange", "solid", k, v) for k, v in self.ioevents.items()] +
            [("black", "dashed", k, [v]) for k, v in self.eofs.items()]
        ):
            for f in v:
                for n in referenced(f):
                    if n in self.states:
                        line = f.atmt_condname
                        for x in self.actions[f.atmt_condname]:
                            line += "\\l>[%s]" % x.__name__
                        s += '\t"%s" -> "%s" [label="%s", color=%s, style=%s];\n' % (
                            k,
                            n,
                            line,
                            c,
                            sty,
                        )

        # Timeout/timer transitions (no helper indirection is followed)
        for k, timers in self.timeout.items():
            for timer in timers:
                for n in (timer._func.__code__.co_names +
                          timer._func.__code__.co_consts):
                    if n in self.states:
                        line = "%s/%.1fs" % (timer._func.atmt_condname,
                                             timer.get())
                        for x in self.actions[timer._func.atmt_condname]:
                            line += "\\l>[%s]" % x.__name__
                        s += '\t"%s" -> "%s" [label="%s",color=blue];\n' % (k, n, line)  # noqa: E501
        s += "}\n"
        return s

    def graph(self, **kargs):
        # type: (Any) -> Optional[str]
        """Render build_graph() through do_graph (kargs forwarded)."""
        s = self.build_graph()
        return do_graph(s, **kargs)
880
881
class Automaton(metaclass=Automaton_metaclass):
    """
    Base class of Scapy automata.

    Subclasses describe their states, conditions and actions with the
    ``ATMT`` decorators. The class-level tables below are presumably filled
    in by ``Automaton_metaclass`` from those members (TODO confirm in the
    metaclass). Creating an instance immediately starts its control thread
    (``__init__`` calls ``start()``); ``run()``, ``runbg()`` and ``next()``
    then drive it.
    """
    # State methods, by attribute name (each one is re-wrapped per instance
    # in __init__).
    states = {}             # type: Dict[str, _StateWrapper]
    # Current state: initial_states[0](self) at start, then each
    # NewStateRequested that gets taken.
    state = None            # type: ATMT.NewStateRequested
    # Per-state lists of conditions, looked up with `self.state.state`.
    recv_conditions = {}    # type: Dict[str, List[_StateWrapper]]
    conditions = {}         # type: Dict[str, List[_StateWrapper]]
    # Per-state condition taken when the listening socket hits EOF.
    eofs = {}               # type: Dict[str, _StateWrapper]
    ioevents = {}           # type: Dict[str, List[_StateWrapper]]
    timeout = {}            # type: Dict[str, _TimerList]
    # Actions run when a condition is taken, keyed by condition name.
    actions = {}            # type: Dict[str, List[_StateWrapper]]
    initial_states = []     # type: List[_StateWrapper]
    # State jumped to on a STOP command (if None, STOP acts as FORCESTOP).
    stop_state = None       # type: Optional[_StateWrapper]
    # Names of the io channels created in __init__ (self.io.<name> and
    # self.oi.<name>).
    ionames = []            # type: List[str]
    iosupersockets = []     # type: List[SuperSocket]

    # used for spawn()
    pkt_cls = conf.raw_layer
    socketcls = StreamSocket
899
900    # Internals
    def __init__(self, *args, **kargs):
        # type: (Any, Any) -> None
        """
        Build the automaton and start its control thread.

        Keyword arguments consumed here: ``external_fd`` (per-ioname
        external descriptors, either one value or a ``(in, out)`` tuple),
        ``ll`` / ``recvsock`` (send / receive socket classes, only when no
        ``sock`` is given), ``is_atmt_socket`` and ``atmt_socket``.
        ``sock`` selects a single bi-directional socket. It is NOT popped, so
        it stays in ``init_kargs``. The remaining ``args``/``kargs`` are kept
        as ``init_args``/``init_kargs`` and used by the control thread when
        it calls ``parse_args()``.
        """
        external_fd = kargs.pop("external_fd", {})
        if "sock" in kargs:
            # We use a bi-directional sock
            self.sock = kargs["sock"]
        else:
            # Separate sockets
            self.sock = None
            self.send_sock_class = kargs.pop("ll", conf.L3socket)
            self.recv_sock_class = kargs.pop("recvsock", conf.L2listen)
        self.listen_sock = None  # type: Optional[SuperSocket]
        self.send_sock = None  # type: Optional[SuperSocket]
        self.is_atmt_socket = kargs.pop("is_atmt_socket", False)
        self.atmt_socket = kargs.pop("atmt_socket", None)
        # Held by the control thread for its whole life (see isrunning())
        self.started = threading.Lock()
        self.threadid = None                # type: Optional[int]
        self.breakpointed = None
        self.breakpoints = set()            # type: Set[_StateWrapper]
        self.interception_points = set()    # type: Set[_StateWrapper]
        self.intercepted_packet = None      # type: Union[None, Packet]
        self.debug_level = 0
        self.init_args = args
        self.init_kargs = kargs
        # Two empty namespaces: io.<name> reads ioout / writes ioin, while
        # oi.<name> reads ioin / writes ioout (see the loop below).
        self.io = type.__new__(type, "IOnamespace", (), {})
        self.oi = type.__new__(type, "IOnamespace", (), {})
        # Command pipes: cmdin carries commands to the control thread,
        # cmdout carries its answers back.
        self.cmdin = ObjectPipe[Message]("cmdin")
        self.cmdout = ObjectPipe[Message]("cmdout")
        self.ioin = {}
        self.ioout = {}
        self.packets = PacketList()                 # type: PacketList
        for n in self.__class__.ionames:
            extfd = external_fd.get(n)
            # A single value is used for both directions
            if not isinstance(extfd, tuple):
                extfd = (extfd, extfd)
            ioin, ioout = extfd
            # No external descriptor: use an internal ObjectPipe
            if ioin is None:
                ioin = ObjectPipe("ioin")
            else:
                ioin = self._IO_fdwrapper(ioin, None)
            if ioout is None:
                ioout = ObjectPipe("ioout")
            else:
                ioout = self._IO_fdwrapper(None, ioout)

            self.ioin[n] = ioin
            self.ioout[n] = ioout
            ioin.ioname = n
            ioout.ioname = n
            setattr(self.io, n, self._IO_mixer(ioout, ioin))
            setattr(self.oi, n, self._IO_mixer(ioin, ioout))

        # Replace each state method on the instance by _instance_state() of it
        for stname in self.states:
            setattr(self, stname,
                    _instance_state(getattr(self, stname)))

        self.start()
958
959    def parse_args(self, debug=0, store=0, **kargs):
960        # type: (int, int, Any) -> None
961        self.debug_level = debug
962        if debug:
963            conf.logLevel = logging.DEBUG
964        self.socket_kargs = kargs
965        self.store_packets = store
966
967    @classmethod
968    def spawn(cls,
969              port: int,
970              iface: Optional[_GlobInterfaceType] = None,
971              bg: bool = False,
972              **kwargs: Any) -> Optional[socket.socket]:
973        """
974        Spawn a TCP server that listens for connections and start the automaton
975        for each new client.
976
977        :param port: the port to listen to
978        :param bg: background mode? (default: False)
979
980        Note that in background mode, you shall close the TCP server as such::
981
982            srv = MyAutomaton.spawn(8080, bg=True)
983            srv.shutdown(socket.SHUT_RDWR)  # important
984            srv.close()
985        """
986        from scapy.arch import get_if_addr
987        # create server sock and bind it
988        ssock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
989        local_ip = get_if_addr(iface or conf.iface)
990        try:
991            ssock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
992        except OSError:
993            pass
994        ssock.bind((local_ip, port))
995        ssock.listen(5)
996        clients = []
997        if kwargs.get("verb", True):
998            print(conf.color_theme.green(
999                "Server %s started listening on %s" % (
1000                    cls.__name__,
1001                    (local_ip, port),
1002                )
1003            ))
1004
1005        def _run() -> None:
1006            # Wait for clients forever
1007            try:
1008                while True:
1009                    atmt_server = None
1010                    clientsocket, address = ssock.accept()
1011                    if kwargs.get("verb", True):
1012                        print(conf.color_theme.gold(
1013                            "\u2503 Connection received from %s" % repr(address)
1014                        ))
1015                    try:
1016                        # Start atmt class with socket
1017                        if cls.socketcls is not None:
1018                            sock = cls.socketcls(clientsocket, cls.pkt_cls)
1019                        else:
1020                            sock = clientsocket
1021                        atmt_server = cls(
1022                            sock=sock,
1023                            iface=iface, **kwargs
1024                        )
1025                    except OSError:
1026                        if atmt_server is not None:
1027                            atmt_server.destroy()
1028                        if kwargs.get("verb", True):
1029                            print("X Connection aborted.")
1030                        if kwargs.get("debug", 0) > 0:
1031                            traceback.print_exc()
1032                        continue
1033                    clients.append((atmt_server, clientsocket))
1034                    # start atmt
1035                    atmt_server.runbg()
1036                    # housekeeping
1037                    for atmt, clientsocket in clients:
1038                        if not atmt.isrunning():
1039                            atmt.destroy()
1040            except KeyboardInterrupt:
1041                print("X Exiting.")
1042                ssock.shutdown(socket.SHUT_RDWR)
1043            except OSError:
1044                print("X Server closed.")
1045                if kwargs.get("debug", 0) > 0:
1046                    traceback.print_exc()
1047            finally:
1048                for atmt, clientsocket in clients:
1049                    try:
1050                        atmt.forcestop(wait=False)
1051                        atmt.destroy()
1052                    except Exception:
1053                        pass
1054                    try:
1055                        clientsocket.shutdown(socket.SHUT_RDWR)
1056                        clientsocket.close()
1057                    except Exception:
1058                        pass
1059                ssock.close()
1060        if bg:
1061            # Background
1062            threading.Thread(target=_run).start()
1063            return ssock
1064        else:
1065            # Non-background
1066            _run()
1067            return None
1068
1069    def master_filter(self, pkt):
1070        # type: (Packet) -> bool
1071        return True
1072
1073    def my_send(self, pkt, **kwargs):
1074        # type: (Packet, **Any) -> None
1075        if not self.send_sock:
1076            raise ValueError("send_sock is None !")
1077        self.send_sock.send(pkt, **kwargs)
1078
1079    def update_sock(self, sock):
1080        # type: (SuperSocket) -> None
1081        """
1082        Update the socket used by the automata.
1083        Typically used in an eof event to reconnect.
1084        """
1085        self.sock = sock
1086        if self.listen_sock is not None:
1087            self.listen_sock = self.sock
1088        if self.send_sock:
1089            self.send_sock = self.sock
1090
1091    def timer_by_name(self, name):
1092        # type: (str) -> Optional[Timer]
1093        for _, timers in self.timeout.items():
1094            for timer in timers:  # type: Timer
1095                if timer._func.atmt_condname == name:
1096                    return timer
1097        return None
1098
1099    # Utility classes and exceptions
1100    class _IO_fdwrapper:
1101        def __init__(self,
1102                     rd,  # type: Union[int, ObjectPipe[bytes], None]
1103                     wr  # type: Union[int, ObjectPipe[bytes], None]
1104                     ):
1105            # type: (...) -> None
1106            self.rd = rd
1107            self.wr = wr
1108            if isinstance(self.rd, socket.socket):
1109                self.__selectable_force_select__ = True
1110
1111        def fileno(self):
1112            # type: () -> int
1113            if isinstance(self.rd, int):
1114                return self.rd
1115            elif self.rd:
1116                return self.rd.fileno()
1117            return 0
1118
1119        def read(self, n=65535):
1120            # type: (int) -> Optional[bytes]
1121            if isinstance(self.rd, int):
1122                return os.read(self.rd, n)
1123            elif self.rd:
1124                return self.rd.recv(n)
1125            return None
1126
1127        def write(self, msg):
1128            # type: (bytes) -> int
1129            if isinstance(self.wr, int):
1130                return os.write(self.wr, msg)
1131            elif self.wr:
1132                return self.wr.send(msg)
1133            return 0
1134
1135        def recv(self, n=65535):
1136            # type: (int) -> Optional[bytes]
1137            return self.read(n)
1138
1139        def send(self, msg):
1140            # type: (bytes) -> int
1141            return self.write(msg)
1142
1143    class _IO_mixer:
1144        def __init__(self,
1145                     rd,  # type: ObjectPipe[Any]
1146                     wr,  # type: ObjectPipe[Any]
1147                     ):
1148            # type: (...) -> None
1149            self.rd = rd
1150            self.wr = wr
1151
1152        def fileno(self):
1153            # type: () -> Any
1154            if isinstance(self.rd, ObjectPipe):
1155                return self.rd.fileno()
1156            return self.rd
1157
1158        def recv(self, n=None):
1159            # type: (Optional[int]) -> Any
1160            return self.rd.recv(n)
1161
1162        def read(self, n=None):
1163            # type: (Optional[int]) -> Any
1164            return self.recv(n)
1165
1166        def send(self, msg):
1167            # type: (str) -> int
1168            return self.wr.send(msg)
1169
1170        def write(self, msg):
1171            # type: (str) -> int
1172            return self.send(msg)
1173
    class AutomatonException(Exception):
        """
        Base class of the exceptions of an Automaton.

        :param msg: the exception message
        :param state: the state involved (callers pass ``<...>.state``),
            stored as ``self.state``
        :param result: an associated result (e.g. the output of the state),
            stored as ``self.result``
        """
        def __init__(self, msg, state=None, result=None):
            # type: (str, Optional[Any], Optional[Any]) -> None
            Exception.__init__(self, msg)
            self.state = state
            self.result = result

    class AutomatonError(AutomatonException):
        """Generic automaton error (e.g. unknown interception verdict)."""
        pass

    class ErrorState(AutomatonException):
        """Raised when the automaton reaches a state flagged as an error."""
        pass

    class Stuck(AutomatonException):
        """
        Raised when a state has no receive condition, no ioevent and no
        timer left once its immediate conditions were tried.
        """
        pass

    class AutomatonStopped(AutomatonException):
        """Base class of the events that pause the automaton."""
        pass

    class Breakpoint(AutomatonStopped):
        """A breakpoint state was entered."""
        pass

    class Singlestep(AutomatonStopped):
        """One step was executed in single-step mode."""
        pass

    class InterceptionPoint(AutomatonStopped):
        """A packet sent from an interception point is awaiting a verdict."""
        def __init__(self, msg, state=None, result=None, packet=None):
            # type: (str, Optional[Any], Optional[Any], Optional[Packet]) -> None
            Automaton.AutomatonStopped.__init__(self, msg, state=state, result=result)
            # The intercepted packet
            self.packet = packet

    class CommandMessage(AutomatonException):
        """Yielded by the state loop when a command is pending on cmdin."""
        pass
1207
1208    # Services
1209    def debug(self, lvl, msg):
1210        # type: (int, str) -> None
1211        if self.debug_level >= lvl:
1212            log_runtime.debug(msg)
1213
1214    def isrunning(self):
1215        # type: () -> bool
1216        return self.started.locked()
1217
    def send(self, pkt, **kwargs):
        # type: (Packet, **Any) -> None
        """
        Send ``pkt`` through ``my_send()``, after the interception mechanism
        when the current state is an interception point.

        When intercepted, an INTERCEPT message carrying the packet is sent on
        ``cmdout`` and this call blocks until a verdict is read from
        ``cmdin``. REJECT drops the packet, REPLACE sends ``cmd.pkt``
        instead, and ACCEPT sends it unchanged. If ``cmdin.recv()`` returns
        nothing, the packet is not sent.

        Sent packets are copied into ``self.packets`` when ``store_packets``
        is set.

        :raises AutomatonError: on an unknown verdict.
        """
        if self.state.state in self.interception_points:
            self.debug(3, "INTERCEPT: packet intercepted: %s" % pkt.summary())
            self.intercepted_packet = pkt
            self.cmdout.send(
                Message(type=_ATMT_Command.INTERCEPT,
                        state=self.state, pkt=pkt)
            )
            # Block until the user answers (accept_packet/reject_packet)
            cmd = self.cmdin.recv()
            if not cmd:
                # NOTE(review): intercepted_packet is left set on this path
                self.debug(3, "CANCELLED")
                return
            self.intercepted_packet = None
            if cmd.type == _ATMT_Command.REJECT:
                self.debug(3, "INTERCEPT: packet rejected")
                return
            elif cmd.type == _ATMT_Command.REPLACE:
                pkt = cmd.pkt
                self.debug(3, "INTERCEPT: packet replaced by: %s" % pkt.summary())  # noqa: E501
            elif cmd.type == _ATMT_Command.ACCEPT:
                self.debug(3, "INTERCEPT: packet accepted")
            else:
                raise self.AutomatonError("INTERCEPT: unknown verdict: %r" % cmd.type)  # noqa: E501
        self.my_send(pkt, **kwargs)
        self.debug(3, "SENT : %s" % pkt.summary())

        if self.store_packets:
            self.packets.append(pkt.copy())
1247
1248    def __iter__(self):
1249        # type: () -> Automaton
1250        return self
1251
1252    def __del__(self):
1253        # type: () -> None
1254        self.destroy()
1255
    def _run_condition(self, cond, *args, **kargs):
        # type: (_StateWrapper, Any, Any) -> None
        """
        Call the condition ``cond`` with ``args``/``kargs``.

        A condition requests a transition by raising
        ``ATMT.NewStateRequested``. In that case:

        - for a RECV condition, the received packet (first argument) is
          stored if ``store_packets`` is set;
        - the actions bound to the condition run with the arguments carried
          by the request;
        - the request is re-raised for the caller.

        Any other exception is logged and re-raised. Returning normally
        means the condition was not taken.
        """
        try:
            self.debug(5, "Trying %s [%s]" % (cond.atmt_type, cond.atmt_condname))  # noqa: E501
            cond(self, *args, **kargs)
        except ATMT.NewStateRequested as state_req:
            self.debug(2, "%s [%s] taken to state [%s]" % (cond.atmt_type, cond.atmt_condname, state_req.state))  # noqa: E501
            if cond.atmt_type == ATMT.RECV:
                if self.store_packets:
                    self.packets.append(args[0])
            for action in self.actions[cond.atmt_condname]:
                self.debug(2, "   + Running action [%s]" % action.__name__)
                action(self, *state_req.action_args, **state_req.action_kargs)
            raise
        except Exception as e:
            self.debug(2, "%s [%s] raised exception [%s]" % (cond.atmt_type, cond.atmt_condname, e))  # noqa: E501
            raise
        else:
            self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname))  # noqa: E501
1275
1276    def _do_start(self, *args, **kargs):
1277        # type: (Any, Any) -> None
1278        ready = threading.Event()
1279        _t = threading.Thread(
1280            target=self._do_control,
1281            args=(ready,) + (args),
1282            kwargs=kargs,
1283            name="scapy.automaton _do_start"
1284        )
1285        _t.daemon = True
1286        _t.start()
1287        ready.wait()
1288
1289    def _do_control(self, ready, *args, **kargs):
1290        # type: (threading.Event, Any, Any) -> None
1291        with self.started:
1292            self.threadid = threading.current_thread().ident
1293            if self.threadid is None:
1294                self.threadid = 0
1295
1296            # Update default parameters
1297            a = args + self.init_args[len(args):]
1298            k = self.init_kargs.copy()
1299            k.update(kargs)
1300            self.parse_args(*a, **k)
1301
1302            # Start the automaton
1303            self.state = self.initial_states[0](self)
1304            self.send_sock = self.sock or self.send_sock_class(**self.socket_kargs)
1305            if self.recv_conditions:
1306                # Only start a receiving socket if we have at least one recv_conditions
1307                self.listen_sock = self.sock or self.recv_sock_class(**self.socket_kargs)  # noqa: E501
1308            self.packets = PacketList(name="session[%s]" % self.__class__.__name__)
1309
1310            singlestep = True
1311            iterator = self._do_iter()
1312            self.debug(3, "Starting control thread [tid=%i]" % self.threadid)
1313            # Sync threads
1314            ready.set()
1315            try:
1316                while True:
1317                    c = self.cmdin.recv()
1318                    if c is None:
1319                        return None
1320                    self.debug(5, "Received command %s" % c.type)
1321                    if c.type == _ATMT_Command.RUN:
1322                        singlestep = False
1323                    elif c.type == _ATMT_Command.NEXT:
1324                        singlestep = True
1325                    elif c.type == _ATMT_Command.FREEZE:
1326                        continue
1327                    elif c.type == _ATMT_Command.STOP:
1328                        if self.stop_state:
1329                            # There is a stop state
1330                            self.state = self.stop_state()
1331                            iterator = self._do_iter()
1332                        else:
1333                            # Act as FORCESTOP
1334                            break
1335                    elif c.type == _ATMT_Command.FORCESTOP:
1336                        break
1337                    while True:
1338                        state = next(iterator)
1339                        if isinstance(state, self.CommandMessage):
1340                            break
1341                        elif isinstance(state, self.Breakpoint):
1342                            c = Message(type=_ATMT_Command.BREAKPOINT, state=state)  # noqa: E501
1343                            self.cmdout.send(c)
1344                            break
1345                        if singlestep:
1346                            c = Message(type=_ATMT_Command.SINGLESTEP, state=state)  # noqa: E501
1347                            self.cmdout.send(c)
1348                            break
1349            except (StopIteration, RuntimeError):
1350                c = Message(type=_ATMT_Command.END,
1351                            result=self.final_state_output)
1352                self.cmdout.send(c)
1353            except Exception as e:
1354                exc_info = sys.exc_info()
1355                self.debug(3, "Transferring exception from tid=%i:\n%s" % (self.threadid, "".join(traceback.format_exception(*exc_info))))  # noqa: E501
1356                m = Message(type=_ATMT_Command.EXCEPTION, exception=e, exc_info=exc_info)  # noqa: E501
1357                self.cmdout.send(m)
1358            self.debug(3, "Stopping control thread (tid=%i)" % self.threadid)
1359            self.threadid = None
1360            if self.listen_sock:
1361                self.listen_sock.close()
1362            if self.send_sock:
1363                self.send_sock.close()
1364
    def _do_iter(self):
        # type: () -> Iterator[Union[Automaton.AutomatonException, Automaton.AutomatonStopped, ATMT.NewStateRequested, None]] # noqa: E501
        """
        Generator driving the automaton, starting from ``self.state``.

        For each state entered:

        - yield a ``Breakpoint`` once if the state is a breakpoint;
        - run the state. An error state raises ``ErrorState``; a final state
          stores its output in ``final_state_output`` and ends the generator;
        - unless a command is already pending on ``cmdin``, try the immediate
          conditions, and raise ``Stuck`` if nothing else (receive
          condition, ioevent, timer) can make the automaton move;
        - loop on timers and ``select``: a pending command yields a
          ``CommandMessage``; received packets accepted by
          ``master_filter()`` go through the receive conditions; readable io
          channels go through the matching ioevents.

        When a condition raises ``ATMT.NewStateRequested``, the request
        becomes the current state and is yielded.
        """
        while True:
            try:
                self.debug(1, "## state=[%s]" % self.state.state)

                # Entering a new state. First, call new state function
                # (`breakpointed` avoids triggering the same breakpoint again
                # when resuming from it)
                if self.state.state in self.breakpoints and self.state.state != self.breakpointed:  # noqa: E501
                    self.breakpointed = self.state.state
                    yield self.Breakpoint("breakpoint triggered on state %s" % self.state.state,  # noqa: E501
                                          state=self.state.state)
                self.breakpointed = None
                state_output = self.state.run()
                if self.state.error:
                    raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output),  # noqa: E501
                                          result=state_output, state=self.state.state)  # noqa: E501
                if self.state.final:
                    self.final_state_output = state_output
                    return

                # Normalize the state output into positional arguments for
                # the conditions: None -> (), a non-list -> 1-tuple
                if state_output is None:
                    state_output = ()
                elif not isinstance(state_output, list):
                    state_output = state_output,

                timers = self.timeout[self.state.state]
                # If there are commandMessage, we should skip immediate
                # conditions.
                if not select_objects([self.cmdin], 0):
                    # Then check immediate conditions
                    for cond in self.conditions[self.state.state]:
                        self._run_condition(cond, *state_output)

                    # If still there and no conditions left, we are stuck!
                    if (len(self.recv_conditions[self.state.state]) == 0 and
                        len(self.ioevents[self.state.state]) == 0 and
                            timers.count() == 0):
                        raise self.Stuck("stuck in [%s]" % self.state.state,
                                         state=self.state.state,
                                         result=state_output)

                # Finally listen and pay attention to timeouts
                timers.reset()
                time_previous = time.time()

                fds = [self.cmdin]  # type: List[Union[SuperSocket, ObjectPipe[Any]]]
                select_func = select_objects
                if self.listen_sock and self.recv_conditions[self.state.state]:
                    fds.append(self.listen_sock)
                    # Use the select() provided by the listening socket
                    # (presumably because some sockets need their own
                    # select logic)
                    select_func = self.listen_sock.select  # type: ignore
                for ioev in self.ioevents[self.state.state]:
                    fds.append(self.ioin[ioev.atmt_ioname])
                while True:
                    time_current = time.time()
                    timers.decrement(time_current - time_previous)
                    time_previous = time_current
                    for timer in timers.expired():
                        self._run_condition(timer._func, *state_output)
                    # Select timeout: until the next timer expires
                    remain = timers.until_next()

                    self.debug(5, "Select on %r" % fds)
                    r = select_func(fds, remain)
                    self.debug(5, "Selected %r" % r)
                    for fd in r:
                        self.debug(5, "Looking at %r" % fd)
                        if fd == self.cmdin:
                            yield self.CommandMessage("Received command message")  # noqa: E501
                        elif fd == self.listen_sock:
                            try:
                                pkt = self.listen_sock.recv()
                            except EOFError:
                                # Socket was closed abruptly. This will likely only
                                # ever happen when a client socket is passed to the
                                # automaton (not the case when the automaton is
                                # listening on a promiscuous conf.L2sniff)
                                self.listen_sock.close()
                                # False so that it is still reset by update_sock
                                self.listen_sock = False  # type: ignore
                                fds.remove(fd)
                                if self.state.state in self.eofs:
                                    # There is an eof state
                                    # NOTE(review): the eof condition is called
                                    # directly, not through _run_condition(),
                                    # so its actions are not run here
                                    eof = self.eofs[self.state.state]
                                    self.debug(2, "Condition EOF [%s] taken" % eof.__name__)  # noqa: E501
                                    raise self.eofs[self.state.state](self)
                                else:
                                    # There isn't. Therefore, it's a closing condition.
                                    raise EOFError("Socket ended arbruptly.")
                            if pkt is not None:
                                if self.master_filter(pkt):
                                    self.debug(3, "RECVD: %s" % pkt.summary())  # noqa: E501
                                    for rcvcond in self.recv_conditions[self.state.state]:  # noqa: E501
                                        self._run_condition(rcvcond, pkt, *state_output)  # noqa: E501
                                else:
                                    self.debug(4, "FILTR: %s" % pkt.summary())  # noqa: E501
                        else:
                            # One of the io channels (ioin) is readable
                            self.debug(3, "IOEVENT on %s" % fd.ioname)
                            for ioevt in self.ioevents[self.state.state]:
                                if ioevt.atmt_ioname == fd.ioname:
                                    self._run_condition(ioevt, fd, *state_output)  # noqa: E501

            except ATMT.NewStateRequested as state_req:
                self.debug(2, "switching from [%s] to [%s]" % (self.state.state, state_req.state))  # noqa: E501
                self.state = state_req
                yield state_req
1469
1470    def __repr__(self):
1471        # type: () -> str
1472        return "<Automaton %s [%s]>" % (
1473            self.__class__.__name__,
1474            ["HALTED", "RUNNING"][self.isrunning()]
1475        )
1476
1477    # Public API
1478    def add_interception_points(self, *ipts):
1479        # type: (Any) -> None
1480        for ipt in ipts:
1481            if hasattr(ipt, "atmt_state"):
1482                ipt = ipt.atmt_state
1483            self.interception_points.add(ipt)
1484
1485    def remove_interception_points(self, *ipts):
1486        # type: (Any) -> None
1487        for ipt in ipts:
1488            if hasattr(ipt, "atmt_state"):
1489                ipt = ipt.atmt_state
1490            self.interception_points.discard(ipt)
1491
1492    def add_breakpoints(self, *bps):
1493        # type: (Any) -> None
1494        for bp in bps:
1495            if hasattr(bp, "atmt_state"):
1496                bp = bp.atmt_state
1497            self.breakpoints.add(bp)
1498
1499    def remove_breakpoints(self, *bps):
1500        # type: (Any) -> None
1501        for bp in bps:
1502            if hasattr(bp, "atmt_state"):
1503                bp = bp.atmt_state
1504            self.breakpoints.discard(bp)
1505
1506    def start(self, *args, **kargs):
1507        # type: (Any, Any) -> None
1508        if self.isrunning():
1509            raise ValueError("Already started")
1510        # Start the control thread
1511        self._do_start(*args, **kargs)
1512
1513    def run(self,
1514            resume=None,    # type: Optional[Message]
1515            wait=True       # type: Optional[bool]
1516            ):
1517        # type: (...) -> Any
1518        if resume is None:
1519            resume = Message(type=_ATMT_Command.RUN)
1520        self.cmdin.send(resume)
1521        if wait:
1522            try:
1523                c = self.cmdout.recv()
1524                if c is None:
1525                    return None
1526            except KeyboardInterrupt:
1527                self.cmdin.send(Message(type=_ATMT_Command.FREEZE))
1528                return None
1529            if c.type == _ATMT_Command.END:
1530                return c.result
1531            elif c.type == _ATMT_Command.INTERCEPT:
1532                raise self.InterceptionPoint("packet intercepted", state=c.state.state, packet=c.pkt)  # noqa: E501
1533            elif c.type == _ATMT_Command.SINGLESTEP:
1534                raise self.Singlestep("singlestep state=[%s]" % c.state.state, state=c.state.state)  # noqa: E501
1535            elif c.type == _ATMT_Command.BREAKPOINT:
1536                raise self.Breakpoint("breakpoint triggered on state [%s]" % c.state.state, state=c.state.state)  # noqa: E501
1537            elif c.type == _ATMT_Command.EXCEPTION:
1538                # this code comes from the `six` module (`.reraise()`)
1539                # to raise an exception with specified exc_info.
1540                value = c.exc_info[0]() if c.exc_info[1] is None else c.exc_info[1]  # type: ignore  # noqa: E501
1541                if value.__traceback__ is not c.exc_info[2]:
1542                    raise value.with_traceback(c.exc_info[2])
1543                raise value
1544        return None
1545
1546    def runbg(self, resume=None, wait=False):
1547        # type: (Optional[Message], Optional[bool]) -> None
1548        self.run(resume, wait)
1549
1550    def __next__(self):
1551        # type: () -> Any
1552        return self.run(resume=Message(type=_ATMT_Command.NEXT))
1553
1554    def _flush_inout(self):
1555        # type: () -> None
1556        # Flush command pipes
1557        for cmd in [self.cmdin, self.cmdout]:
1558            cmd.clear()
1559
1560    def destroy(self):
1561        # type: () -> None
1562        """
1563        Destroys a stopped Automaton: this cleanups all opened file descriptors.
1564        Required on PyPy for instance where the garbage collector behaves differently.
1565        """
1566        if not hasattr(self, "started"):
1567            return  # was never started.
1568        if self.isrunning():
1569            raise ValueError("Can't close running Automaton ! Call stop() beforehand")
1570        # Close command pipes
1571        self.cmdin.close()
1572        self.cmdout.close()
1573        self._flush_inout()
1574        # Close opened ioins/ioouts
1575        for i in itertools.chain(self.ioin.values(), self.ioout.values()):
1576            if isinstance(i, ObjectPipe):
1577                i.close()
1578
1579    def stop(self, wait=True):
1580        # type: (bool) -> None
1581        try:
1582            self.cmdin.send(Message(type=_ATMT_Command.STOP))
1583        except OSError:
1584            pass
1585        if wait:
1586            with self.started:
1587                self._flush_inout()
1588
1589    def forcestop(self, wait=True):
1590        # type: (bool) -> None
1591        try:
1592            self.cmdin.send(Message(type=_ATMT_Command.FORCESTOP))
1593        except OSError:
1594            pass
1595        if wait:
1596            with self.started:
1597                self._flush_inout()
1598
1599    def restart(self, *args, **kargs):
1600        # type: (Any, Any) -> None
1601        self.stop()
1602        self.start(*args, **kargs)
1603
1604    def accept_packet(self,
1605                      pkt=None,     # type: Optional[Packet]
1606                      wait=False    # type: Optional[bool]
1607                      ):
1608        # type: (...) -> Any
1609        rsm = Message()
1610        if pkt is None:
1611            rsm.type = _ATMT_Command.ACCEPT
1612        else:
1613            rsm.type = _ATMT_Command.REPLACE
1614            rsm.pkt = pkt
1615        return self.run(resume=rsm, wait=wait)
1616
1617    def reject_packet(self,
1618                      wait=False    # type: Optional[bool]
1619                      ):
1620        # type: (...) -> Any
1621        rsm = Message(type=_ATMT_Command.REJECT)
1622        return self.run(resume=rsm, wait=wait)
1623