• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1## This file is part of Scapy
## See http://www.secdev.org/projects/scapy for more information
3## Copyright (C) Philippe Biondi <phil@secdev.org>
4## Copyright (C) Gabriel Potter <gabriel@potter.fr>
5## This program is published under a GPLv2 license
6
7"""
8Automata with states, transitions and actions.
9"""
10
from __future__ import absolute_import
import types,itertools,time,os,sys,socket,traceback
from select import select
from collections import deque
import threading
from scapy.config import conf
from scapy.utils import do_graph
from scapy.error import log_interactive, warning
from scapy.plist import PacketList
from scapy.data import MTU
from scapy.supersocket import SuperSocket
from scapy.consts import WINDOWS
from scapy.compat import *
import scapy.modules.six as six
25
# Exception type raised when releasing a lock that is not held:
# ``thread.error`` on Python 2, ``RuntimeError`` on Python 3 (where the
# ``thread`` module no longer exists).
try:
    import thread
    THREAD_EXCEPTION = thread.error
except ImportError:
    THREAD_EXCEPTION = RuntimeError
32
# Exceptions that _ATMT_supersocket.recv() turns into a ``None`` result.
# Outside Windows an empty tuple is used, so the ``except`` clause never
# matches anything.
if not WINDOWS:
    recv_error = ()
else:
    from scapy.error import Scapy_Exception
    recv_error = Scapy_Exception
38
""" On Windows, select.select is not available for custom objects. Here is Scapy's implementation that re-creates this functionality
# Passive way: using resource-free locks
41               +---------+             +---------------+      +-------------------------+
42               |  Start  +------------->Select_objects +----->+Linux: call select.select|
43               +---------+             |(select.select)|      +-------------------------+
44                                       +-------+-------+
45                                               |
46                                          +----v----+               +--------+
47                                          | Windows |               |Time Out+----------------------------------+
48                                          +----+----+               +----+---+                                  |
49                                               |                         ^                                      |
50      Event                                    |                         |                                      |
51        +                                      |                         |                                      |
52        |                              +-------v-------+                 |                                      |
53        |                       +------+Selectable Sel.+-----+-----------------+-----------+                    |
54        |                       |      +-------+-------+     |           |     |           v              +-----v-----+
55+-------v----------+            |              |             |           |     |        Passive lock<-----+release_all<------+
56|Data added to list|       +----v-----+  +-----v-----+  +----v-----+     v     v            +             +-----------+      |
57+--------+---------+       |Selectable|  |Selectable |  |Selectable|   ............         |                                |
58         |                 +----+-----+  +-----------+  +----------+                        |                                |
59         |                      v                                                           |                                |
60         v                 +----+------+   +------------------+               +-------------v-------------------+            |
61   +-----+------+          |wait_return+-->+  check_recv:     |               |                                 |            |
62   |call_release|          +----+------+   |If data is in list|               |  END state: selectable returned |        +---+--------+
63   +-----+--------              v          +-------+----------+               |                                 |        | exit door  |
64         |                    else                 |                          +---------------------------------+        +---+--------+
65         |                      +                  |                                                                         |
66         |                 +----v-------+          |                                                                         |
67         +--------->free -->Passive lock|          |                                                                         |
68                           +----+-------+          |                                                                         |
69                                |                  |                                                                         |
70                                |                  v                                                                         |
71                                +------------------Selectable-Selector-is-advertised-that-the-selectable-is-readable---------+
72"""
73
class SelectableObject:
    """Base class for objects that select_objects() can wait on (Windows).

    DEV: to implement one of those, you need to add 2 things to your object:
    - add "check_recv" function
    - call "self.call_release" once you are ready to be read

    You can set the __selectable_force_select__ to True in the class, if you
    want to force the handler to use fileno(). This may only be usable on
    sockets created using the builtin socket API."""
    __selectable_force_select__ = False

    def check_recv(self):
        """DEV: return True if the object can be read right now.

        Called by wait_return() before any waiter thread is started, so data
        that is already available is reported without waiting."""
        raise OSError("This method must be overwriten.")

    def _arm_trigger(self):
        """Create a new, already-acquired trigger lock and publish it.

        The lock is acquired *before* it becomes visible as ``self.trigger``,
        so a concurrent call_release() always finds a held lock to release
        instead of failing and losing the wake-up."""
        trigger = threading.Lock()
        trigger.acquire()
        self.was_ended = False
        self.trigger = trigger
        return trigger

    def _wait_non_ressources(self, callback, trigger=None):
        """Waiter thread body.

        Block until call_release() frees ``trigger``, then advertise this
        object to the SelectableSelector through ``callback`` unless the wait
        was aborted (call_release(True)).

        :param callback: called as ``callback(self)`` once ready.
        :param trigger: the lock armed by wait_return() before this thread
            was started. When omitted, a fresh one is armed here; that legacy
            path can miss a call_release() made before this point.
        """
        if trigger is None:
            trigger = self._arm_trigger()
        trigger.acquire()
        if not self.was_ended:
            callback(self)

    def wait_return(self, callback):
        """Entry point of SelectableObject: register ``callback``.

        If the object is already readable, ``callback(self)`` is called
        synchronously and its result returned. Otherwise a daemon thread
        waits for call_release() and then calls ``callback(self)``; None is
        returned. The trigger is armed before check_recv(), so data arriving
        between the check and the start of the thread is not missed.
        """
        trigger = self._arm_trigger()
        if self.check_recv():
            return callback(self)
        # Hand the lock to the thread explicitly: a later wait_return() may
        # replace self.trigger before this thread gets to run.
        waiter = threading.Thread(target=self._wait_non_ressources,
                                  args=(callback, trigger))
        waiter.daemon = True
        waiter.start()

    def call_release(self, arborted=False):
        """DEV: Must be called when the object becomes ready to read.

        Releases the lock the waiter thread started by wait_return() blocks
        on. With ``arborted`` True the wait is cancelled: the waiter wakes up
        but does not call its callback. Calling this when nobody is waiting
        is harmless (the resulting lock errors are ignored).
        """
        self.was_ended = arborted
        try:
            self.trigger.release()
        except (THREAD_EXCEPTION, AttributeError):
            pass
112
class SelectableSelector(object):
    """
    Select SelectableObject objects.

    inputs: objects to process
    remain: timeout in seconds. On Windows, a falsy value (0 or None) means
            "do not wait": only the objects that are readable right now are
            returned.
    """
    def _release_all(self):
        """Releases all locks to kill all threads"""
        for i in self.inputs:
            i.call_release(True)
        self.available_lock.release()

    def _try_end(self):
        """Atomically mark the selection as finished.

        Returns True only for the first caller. This ensures _release_all()
        (and its single release of available_lock) runs exactly once, even
        when the timeout thread and a ready object race each other."""
        with self._ended_lock:
            if self._ended:
                return False
            self._ended = True
            return True

    def _timeout_thread(self, remain):
        """Timeout before releasing every thing, if nothing was returned"""
        time.sleep(remain)
        if self._try_end():
            self._release_all()

    def _exit_door(self, _input):
        """This function is passed to each SelectableObject as a callback
        The SelectableObjects have to call it once there are ready"""
        self.results.append(_input)
        if self._try_end():
            self._release_all()

    def __init__(self, inputs, remain):
        self.results = []
        self.inputs = list(inputs)
        self.remain = remain
        # Held until the selection ends; process() blocks on it.
        self.available_lock = threading.Lock()
        self.available_lock.acquire()
        self._ended = False
        # Guards the check-and-set of _ended (see _try_end).
        self._ended_lock = threading.Lock()

    def process(self):
        """Entry point of SelectableSelector"""
        if not WINDOWS:
            r, _, _ = select(self.inputs, [], [], self.remain)
            return r
        select_inputs = []
        for i in self.inputs:
            if not isinstance(i, SelectableObject):
                warning("Unknown ignored object type: %s", type(i))
            elif i.__selectable_force_select__:
                # Then use select.select
                select_inputs.append(i)
            elif not self.remain or self._ended:
                # We are not going to wait (or already have an answer): only
                # report objects readable right now. Do not spawn a waiter
                # thread that nobody would ever release.
                if i.check_recv():
                    self.results.append(i)
            else:
                i.wait_return(self._exit_door)
        if select_inputs:
            # Use default select function
            self.results.extend(select(select_inputs, [], [], self.remain)[0])
        if not self.remain:
            return self.results

        # Daemon, so a pending timeout does not keep the interpreter alive
        timer = threading.Thread(target=self._timeout_thread, args=(self.remain,))
        timer.daemon = True
        timer.start()
        if not self._ended:
            self.available_lock.acquire()
        return self.results
178
def select_objects(inputs, remain):
    """
    Portable select: behaves like ``select.select(inputs, [], [], remain)[0]``
    but, on Windows, also works on SelectableObject instances (and only on
    those).

    inputs: objects to process
    remain: timeout in seconds. If 0, return [] / the objects already
            readable without waiting.
    """
    return SelectableSelector(inputs, remain).process()
190
class ObjectPipe(SelectableObject):
    """Selectable in-process queue of Python objects.

    Every queued object is mirrored by one byte written to an OS pipe, so
    fileno() can be passed to select.select(). send() also calls
    call_release() to wake a thread waiting through wait_return()."""
    def __init__(self):
        read_fd, write_fd = os.pipe()
        self.rd = read_fd
        self.wr = write_fd
        self.queue = deque()

    def fileno(self):
        return self.rd

    def check_recv(self):
        return bool(self.queue)

    def send(self, obj):
        self.queue.append(obj)
        os.write(self.wr, b"X")
        self.call_release()

    def write(self, obj):
        self.send(obj)

    def recv(self, n=0):
        # Consume the token byte first: blocks until an object is queued.
        # ``n`` is accepted for API compatibility and ignored.
        os.read(self.rd, 1)
        return self.queue.popleft()

    def read(self, n=0):
        return self.recv(n)
210
class Message:
    """Lightweight attribute bag built from keyword arguments.

    Used for the commands exchanged with the automaton control thread
    (a ``type`` field plus command-specific fields)."""
    def __init__(self, **args):
        vars(self).update(args)

    def __repr__(self):
        # private (underscore-prefixed) attributes are left out
        fields = ["%s=%r" % (key, value)
                  for key, value in six.iteritems(self.__dict__)
                  if not key.startswith("_")]
        return "<Message %s>" % " ".join(fields)
218
219class _instance_state:
220    def __init__(self, instance):
221        self.__self__ = instance.__self__
222        self.__func__ = instance.__func__
223        self.__self__.__class__ = instance.__self__.__class__
224    def __getattr__(self, attr):
225        return getattr(self.__func__, attr)
226    def __call__(self, *args, **kargs):
227        return self.__func__(self.__self__, *args, **kargs)
228    def breaks(self):
229        return self.__self__.add_breakpoints(self.__func__)
230    def intercepts(self):
231        return self.__self__.add_interception_points(self.__func__)
232    def unbreaks(self):
233        return self.__self__.remove_breakpoints(self.__func__)
234    def unintercepts(self):
235        return self.__self__.remove_interception_points(self.__func__)
236
237
238##############
239## Automata ##
240##############
241
class ATMT:
    """Decorators that declare automaton states, conditions, actions, I/O
    events and timeouts. Also holds the NewStateRequested exception that
    triggers state transitions."""
    STATE = "State"
    ACTION = "Action"
    CONDITION = "Condition"
    RECV = "Receive condition"
    TIMEOUT = "Timeout condition"
    IOEVENT = "I/O event"

    class NewStateRequested(Exception):
        """Transition request produced by calling a state wrapper."""
        def __init__(self, state_func, automaton, *args, **kargs):
            Exception.__init__(self, "Request state [%s]" % state_func.atmt_state)
            self.func = state_func
            self.automaton = automaton
            self.state = state_func.atmt_state
            self.initial, self.error, self.final = (state_func.atmt_initial,
                                                    state_func.atmt_error,
                                                    state_func.atmt_final)
            # NOTE: Exception.args is deliberately reused to carry the
            # arguments later passed to the state function by run().
            self.args, self.kargs = args, kargs
            self.action_parameters()  # no action arguments by default

        def action_parameters(self, *args, **kargs):
            """Store the arguments given to the actions run on this
            transition; returns self so it can be chained."""
            self.action_args, self.action_kargs = args, kargs
            return self

        def run(self):
            """Execute the requested state function on the automaton."""
            return self.func(self.automaton, *self.args, **self.kargs)

        def __repr__(self):
            return "NewStateRequested(%s)" % self.state

    @staticmethod
    def state(initial=0, final=0, error=0):
        """Declare a state. The decorated function is replaced by a wrapper
        that returns a NewStateRequested instead of running the state."""
        def deco(f, initial=initial, final=final):
            def state_wrapper(self, *args, **kargs):
                return ATMT.NewStateRequested(f, self, *args, **kargs)

            state_wrapper.__name__ = "%s_wrapper" % f.__name__
            # Both the original function and its wrapper carry the metadata
            for target in (f, state_wrapper):
                target.atmt_type = ATMT.STATE
                target.atmt_state = f.__name__
                target.atmt_initial = initial
                target.atmt_final = final
                target.atmt_error = error
            state_wrapper.atmt_origfunc = f
            return state_wrapper
        return deco

    @staticmethod
    def action(cond, prio=0):
        """Attach the decorated function as an action of condition ``cond``.
        Can be stacked to bind one action to several conditions."""
        def deco(f, cond=cond):
            if not hasattr(f, "atmt_type"):
                # first @action on this function: start its condition map
                f.atmt_cond = {}
            f.atmt_type = ATMT.ACTION
            f.atmt_cond[cond.atmt_condname] = prio
            return f
        return deco

    @staticmethod
    def condition(state, prio=0):
        """Declare an immediate condition checked when entering ``state``."""
        def deco(f, state=state):
            f.__dict__.update(atmt_type=ATMT.CONDITION,
                              atmt_state=state.atmt_state,
                              atmt_condname=f.__name__,
                              atmt_prio=prio)
            return f
        return deco

    @staticmethod
    def receive_condition(state, prio=0):
        """Declare a condition evaluated on packets received in ``state``."""
        def deco(f, state=state):
            f.__dict__.update(atmt_type=ATMT.RECV,
                              atmt_state=state.atmt_state,
                              atmt_condname=f.__name__,
                              atmt_prio=prio)
            return f
        return deco

    @staticmethod
    def ioevent(state, name, prio=0, as_supersocket=None):
        """Declare a condition triggered by the I/O channel ``name`` in
        ``state``. ``as_supersocket`` optionally names a class attribute
        exposing that channel as a supersocket factory."""
        def deco(f, state=state):
            f.__dict__.update(atmt_type=ATMT.IOEVENT,
                              atmt_state=state.atmt_state,
                              atmt_condname=f.__name__,
                              atmt_ioname=name,
                              atmt_prio=prio,
                              atmt_as_supersocket=as_supersocket)
            return f
        return deco

    @staticmethod
    def timeout(state, timeout):
        """Declare a condition fired ``timeout`` seconds after entering
        ``state``."""
        def deco(f, state=state, timeout=timeout):
            f.__dict__.update(atmt_type=ATMT.TIMEOUT,
                              atmt_state=state.atmt_state,
                              atmt_timeout=timeout,
                              atmt_condname=f.__name__)
            return f
        return deco
338
339class _ATMT_Command:
340    RUN = "RUN"
341    NEXT = "NEXT"
342    FREEZE = "FREEZE"
343    STOP = "STOP"
344    END = "END"
345    EXCEPTION = "EXCEPTION"
346    SINGLESTEP = "SINGLESTEP"
347    BREAKPOINT = "BREAKPOINT"
348    INTERCEPT = "INTERCEPT"
349    ACCEPT = "ACCEPT"
350    REPLACE = "REPLACE"
351    REJECT = "REJECT"
352
class _ATMT_supersocket(SuperSocket):
    """SuperSocket view of one ioevent of a background automaton.

    A datagram socketpair links this object (``spa``) to a new automaton
    instance started with runbg(). The automaton receives the other end
    (``spb``) as the external fd of ``ioevent``."""
    def __init__(self, name, ioevent, automaton, proto, args, kargs):
        self.name = name
        self.ioevent = ioevent
        self.proto = proto
        self.spa, self.spb = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)
        kargs["external_fd"] = {ioevent: self.spb}
        self.atmt = automaton(*args, **kargs)
        self.atmt.runbg()

    def fileno(self):
        return self.spa.fileno()

    def send(self, s):
        if not isinstance(s, bytes):
            s = bytes(s)
        return self.spa.send(s)

    def recv(self, n=MTU):
        """Read one datagram (at most ``n`` bytes) and decode it with
        ``proto`` when one was given. On Windows, errors listed in
        recv_error make it return None instead of raising."""
        try:
            r = self.spa.recv(n)
        except recv_error:
            if not WINDOWS:
                raise
            return None
        if self.proto is not None:
            r = self.proto(r)
        return r

    def close(self):
        """Stop the background automaton and close both socketpair ends.

        Safe to call more than once; later calls do nothing."""
        if getattr(self, "_atmt_closed", False):
            return
        self._atmt_closed = True
        atmt = getattr(self, "atmt", None)
        try:
            if atmt is not None:
                # NOTE(review): relies on Automaton.stop() (outside this
                # view) to end the thread started by runbg().
                atmt.stop()
        finally:
            for sock in (getattr(self, "spa", None), getattr(self, "spb", None)):
                if sock is not None:
                    sock.close()
380
381class _ATMT_to_supersocket:
382    def __init__(self, name, ioevent, automaton):
383        self.name = name
384        self.ioevent = ioevent
385        self.automaton = automaton
386    def __call__(self, proto, *args, **kargs):
387        return _ATMT_supersocket(self.name, self.ioevent, self.automaton, proto, args, kargs)
388
class Automaton_metaclass(type):
    """Metaclass of Automaton.

    At class-creation time it collects every ATMT-decorated function of the
    class and of its bases into lookup tables (states, conditions,
    recv_conditions, ioevents, timeout, actions, initial_states, ionames).
    It also installs a supersocket factory for each ioevent declared with
    ``as_supersocket``.
    """
    def __new__(cls, name, bases, dct):
        cls = super(Automaton_metaclass, cls).__new__(cls, name, bases, dct)
        cls.states = {}
        cls.state = None
        cls.recv_conditions = {}
        cls.conditions = {}
        cls.ioevents = {}
        cls.timeout = {}
        cls.actions = {}
        cls.initial_states = []
        cls.ionames = []
        cls.iosupersockets = []

        # Breadth-first walk of the hierarchy; the first definition of a name
        # wins, so subclass overrides hide the base-class members.
        members = {}
        classes = [cls]
        while classes:
            c = classes.pop(0)  # order is important to avoid breaking method overloading
            classes += list(c.__bases__)
            for k, v in six.iteritems(c.__dict__):
                if k not in members:
                    members[k] = v

        decorated = [v for v in six.itervalues(members)
                     if isinstance(v, types.FunctionType) and hasattr(v, "atmt_type")]

        # First pass: create the per-state and per-condition tables
        for m in decorated:
            if m.atmt_type == ATMT.STATE:
                s = m.atmt_state
                cls.states[s] = m
                cls.recv_conditions[s] = []
                cls.ioevents[s] = []
                cls.conditions[s] = []
                cls.timeout[s] = []
                if m.atmt_initial:
                    cls.initial_states.append(m)
            elif m.atmt_type in [ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT, ATMT.IOEVENT]:
                cls.actions[m.atmt_condname] = []

        # Second pass: file conditions under their state, actions under
        # their conditions
        for m in decorated:
            if m.atmt_type == ATMT.CONDITION:
                cls.conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.RECV:
                cls.recv_conditions[m.atmt_state].append(m)
            elif m.atmt_type == ATMT.IOEVENT:
                cls.ioevents[m.atmt_state].append(m)
                cls.ionames.append(m.atmt_ioname)
                if m.atmt_as_supersocket is not None:
                    cls.iosupersockets.append(m)
            elif m.atmt_type == ATMT.TIMEOUT:
                cls.timeout[m.atmt_state].append((m.atmt_timeout, m))
            elif m.atmt_type == ATMT.ACTION:
                for c in m.atmt_cond:
                    cls.actions[c].append(m)

        # Order everything by delay / priority (stable sorts on plain keys)
        for v in six.itervalues(cls.timeout):
            v.sort(key=lambda timeout_and_func: timeout_and_func[0])
            v.append((None, None))  # sentinel terminating each timeout list
        for v in itertools.chain(six.itervalues(cls.conditions),
                                 six.itervalues(cls.recv_conditions),
                                 six.itervalues(cls.ioevents)):
            v.sort(key=lambda cond: cond.atmt_prio)
        for condname, actlst in six.iteritems(cls.actions):
            actlst.sort(key=lambda act, condname=condname: act.atmt_cond[condname])

        for ioev in cls.iosupersockets:
            setattr(cls, ioev.atmt_as_supersocket,
                    _ATMT_to_supersocket(ioev.atmt_as_supersocket, ioev.atmt_ioname, cls))

        return cls

    def graph(self, **kargs):
        """Render the automaton as a Graphviz digraph through do_graph().

        ``self`` is the automaton class, so the graph is named after it.
        Initial states are drawn as blue boxes, final states as green
        octagons and error states as red octagons. Edges are guessed by
        looking for state names among the names and constants referenced by
        each state/condition function: green for state functions, purple for
        conditions, red for receive conditions, orange for ioevents and blue
        for timeouts, labelled with the condition and its actions.
        ``kargs`` are passed to do_graph().
        """
        s = 'digraph "%s" {\n' % self.__name__

        se = ""  # Keep initial nodes at the begining for better rendering
        for st in six.itervalues(self.states):
            if st.atmt_initial:
                se = ('\t"%s" [ style=filled, fillcolor=blue, shape=box, root=true];\n' % st.atmt_state) + se
            elif st.atmt_final:
                se += '\t"%s" [ style=filled, fillcolor=green, shape=octagon ];\n' % st.atmt_state
            elif st.atmt_error:
                se += '\t"%s" [ style=filled, fillcolor=red, shape=octagon ];\n' % st.atmt_state
        s += se

        for st in six.itervalues(self.states):
            for n in st.atmt_origfunc.__code__.co_names + st.atmt_origfunc.__code__.co_consts:
                if n in self.states:
                    s += '\t"%s" -> "%s" [ color=green ];\n' % (st.atmt_state, n)

        for c, k, v in ([("purple", k, v) for k, v in self.conditions.items()] +
                        [("red", k, v) for k, v in self.recv_conditions.items()] +
                        [("orange", k, v) for k, v in self.ioevents.items()]):
            for f in v:
                for n in f.__code__.co_names + f.__code__.co_consts:
                    if n in self.states:
                        l = f.atmt_condname
                        for x in self.actions[f.atmt_condname]:
                            l += "\\l>[%s]" % x.__name__
                        s += '\t"%s" -> "%s" [label="%s", color=%s];\n' % (k, n, l, c)
        for k, v in six.iteritems(self.timeout):
            for t, f in v:
                if f is None:
                    continue  # end-of-list sentinel
                for n in f.__code__.co_names + f.__code__.co_consts:
                    if n in self.states:
                        l = "%s/%.1fs" % (f.atmt_condname, t)
                        for x in self.actions[f.atmt_condname]:
                            l += "\\l>[%s]" % x.__name__
                        s += '\t"%s" -> "%s" [label="%s",color=blue];\n' % (k, n, l)
        s += "}\n"
        return do_graph(s, **kargs)
501
502class Automaton(six.with_metaclass(Automaton_metaclass)):
503    def parse_args(self, debug=0, store=1, **kargs):
504        self.debug_level=debug
505        self.socket_kargs = kargs
506        self.store_packets = store
507
508    def master_filter(self, pkt):
509        return True
510
511    def my_send(self, pkt):
512        self.send_sock.send(pkt)
513
514
515    ## Utility classes and exceptions
516    class _IO_fdwrapper(SelectableObject):
517        def __init__(self,rd,wr):
518            if WINDOWS:
519                # rd will be used for reading and sending
520                if isinstance(rd, ObjectPipe):
521                    self.rd = rd
522                else:
523                    raise OSError("On windows, only instances of ObjectPipe are externally available")
524            else:
525                if rd is not None and not isinstance(rd, int):
526                    rd = rd.fileno()
527                if wr is not None and not isinstance(wr, int):
528                    wr = wr.fileno()
529                self.rd = rd
530                self.wr = wr
531        def fileno(self):
532            return self.rd
533        def check_recv(self):
534            return self.rd.check_recv()
535        def read(self, n=65535):
536            if WINDOWS:
537                return self.rd.recv(n)
538            return os.read(self.rd, n)
539        def write(self, msg):
540            if WINDOWS:
541                self.rd.send(msg)
542                return self.call_release()
543            return os.write(self.wr,msg)
544        def recv(self, n=65535):
545            return self.read(n)
546        def send(self, msg):
547            return self.write(msg)
548
549    class _IO_mixer(SelectableObject):
550        def __init__(self,rd,wr):
551            self.rd = rd
552            self.wr = wr
553        def fileno(self):
554            if isinstance(self.rd, int):
555                return self.rd
556            return self.rd.fileno()
557        def check_recv(self):
558            return self.rd.check_recv()
559        def recv(self, n=None):
560            return self.rd.recv(n)
561        def read(self, n=None):
562            return self.recv(n)
563        def send(self, msg):
564            self.wr.send(msg)
565            return self.call_release()
566        def write(self, msg):
567            return self.send(msg)
568
569
570    class AutomatonException(Exception):
571        def __init__(self, msg, state=None, result=None):
572            Exception.__init__(self, msg)
573            self.state = state
574            self.result = result
575
576    class AutomatonError(AutomatonException):
577        pass
578    class ErrorState(AutomatonException):
579        pass
580    class Stuck(AutomatonException):
581        pass
582    class AutomatonStopped(AutomatonException):
583        pass
584
585    class Breakpoint(AutomatonStopped):
586        pass
587    class Singlestep(AutomatonStopped):
588        pass
589    class InterceptionPoint(AutomatonStopped):
590        def __init__(self, msg, state=None, result=None, packet=None):
591            Automaton.AutomatonStopped.__init__(self, msg, state=state, result=result)
592            self.packet = packet
593
594    class CommandMessage(AutomatonException):
595        pass
596
597
598    ## Services
599    def debug(self, lvl, msg):
600        if self.debug_level >= lvl:
601            log_interactive.debug(msg)
602
603    def send(self, pkt):
604        if self.state.state in self.interception_points:
605            self.debug(3,"INTERCEPT: packet intercepted: %s" % pkt.summary())
606            self.intercepted_packet = pkt
607            cmd = Message(type = _ATMT_Command.INTERCEPT, state=self.state, pkt=pkt)
608            self.cmdout.send(cmd)
609            cmd = self.cmdin.recv()
610            self.intercepted_packet = None
611            if cmd.type == _ATMT_Command.REJECT:
612                self.debug(3,"INTERCEPT: packet rejected")
613                return
614            elif cmd.type == _ATMT_Command.REPLACE:
615                pkt = cmd.pkt
616                self.debug(3,"INTERCEPT: packet replaced by: %s" % pkt.summary())
617            elif cmd.type == _ATMT_Command.ACCEPT:
618                self.debug(3,"INTERCEPT: packet accepted")
619            else:
620                raise self.AutomatonError("INTERCEPT: unkown verdict: %r" % cmd.type)
621        self.my_send(pkt)
622        self.debug(3,"SENT : %s" % pkt.summary())
623
624        if self.store_packets:
625            self.packets.append(pkt.copy())
626
627
628    ## Internals
629    def __init__(self, *args, **kargs):
630        external_fd = kargs.pop("external_fd",{})
631        self.send_sock_class = kargs.pop("ll", conf.L3socket)
632        self.recv_sock_class = kargs.pop("recvsock", conf.L2listen)
633        self.started = threading.Lock()
634        self.threadid = None
635        self.breakpointed = None
636        self.breakpoints = set()
637        self.interception_points = set()
638        self.intercepted_packet = None
639        self.debug_level=0
640        self.init_args=args
641        self.init_kargs=kargs
642        self.io = type.__new__(type, "IOnamespace",(),{})
643        self.oi = type.__new__(type, "IOnamespace",(),{})
644        self.cmdin = ObjectPipe()
645        self.cmdout = ObjectPipe()
646        self.ioin = {}
647        self.ioout = {}
648        for n in self.ionames:
649            extfd = external_fd.get(n)
650            if not isinstance(extfd, tuple):
651                extfd = (extfd,extfd)
652            elif WINDOWS:
653                raise OSError("Tuples are not allowed as external_fd on windows")
654            ioin,ioout = extfd
655            if ioin is None:
656                ioin = ObjectPipe()
657            elif not isinstance(ioin, SelectableObject):
658                ioin = self._IO_fdwrapper(ioin,None)
659            if ioout is None:
660                ioout = ioin if WINDOWS else ObjectPipe()
661            elif not isinstance(ioout, SelectableObject):
662                ioout = self._IO_fdwrapper(None,ioout)
663
664            self.ioin[n] = ioin
665            self.ioout[n] = ioout
666            ioin.ioname = n
667            ioout.ioname = n
668            setattr(self.io, n, self._IO_mixer(ioout,ioin))
669            setattr(self.oi, n, self._IO_mixer(ioin,ioout))
670
671        for stname in self.states:
672            setattr(self, stname,
673                    _instance_state(getattr(self, stname)))
674
675        self.start()
676
677    def __iter__(self):
678        return self
679
680    def __del__(self):
681        self.stop()
682
683    def _run_condition(self, cond, *args, **kargs):
684        try:
685            self.debug(5, "Trying %s [%s]" % (cond.atmt_type, cond.atmt_condname))
686            cond(self,*args, **kargs)
687        except ATMT.NewStateRequested as state_req:
688            self.debug(2, "%s [%s] taken to state [%s]" % (cond.atmt_type, cond.atmt_condname, state_req.state))
689            if cond.atmt_type == ATMT.RECV:
690                if self.store_packets:
691                    self.packets.append(args[0])
692            for action in self.actions[cond.atmt_condname]:
693                self.debug(2, "   + Running action [%s]" % action.__name__)
694                action(self, *state_req.action_args, **state_req.action_kargs)
695            raise
696        except Exception as e:
697            self.debug(2, "%s [%s] raised exception [%s]" % (cond.atmt_type, cond.atmt_condname, e))
698            raise
699        else:
700            self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname))
701
702    def _do_start(self, *args, **kargs):
703        ready = threading.Event()
704        _t = threading.Thread(target=self._do_control, args=(ready,) + (args), kwargs=kargs)
705        _t.setDaemon(True)
706        _t.start()
707        ready.wait()
708
709    def _do_control(self, ready, *args, **kargs):
710        with self.started:
711            self.threadid = threading.currentThread().ident
712
713            # Update default parameters
714            a = args+self.init_args[len(args):]
715            k = self.init_kargs.copy()
716            k.update(kargs)
717            self.parse_args(*a,**k)
718
719            # Start the automaton
720            self.state=self.initial_states[0](self)
721            self.send_sock = self.send_sock_class(**self.socket_kargs)
722            self.listen_sock = self.recv_sock_class(**self.socket_kargs)
723            self.packets = PacketList(name="session[%s]"%self.__class__.__name__)
724
725            singlestep = True
726            iterator = self._do_iter()
727            self.debug(3, "Starting control thread [tid=%i]" % self.threadid)
728            # Sync threads
729            ready.set()
730            try:
731                while True:
732                    c = self.cmdin.recv()
733                    self.debug(5, "Received command %s" % c.type)
734                    if c.type == _ATMT_Command.RUN:
735                        singlestep = False
736                    elif c.type == _ATMT_Command.NEXT:
737                        singlestep = True
738                    elif c.type == _ATMT_Command.FREEZE:
739                        continue
740                    elif c.type == _ATMT_Command.STOP:
741                        break
742                    while True:
743                        state = next(iterator)
744                        if isinstance(state, self.CommandMessage):
745                            break
746                        elif isinstance(state, self.Breakpoint):
747                            c = Message(type=_ATMT_Command.BREAKPOINT,state=state)
748                            self.cmdout.send(c)
749                            break
750                        if singlestep:
751                            c = Message(type=_ATMT_Command.SINGLESTEP,state=state)
752                            self.cmdout.send(c)
753                            break
754            except StopIteration as e:
755                c = Message(type=_ATMT_Command.END, result=e.args[0])
756                self.cmdout.send(c)
757            except Exception as e:
758                exc_info = sys.exc_info()
759                self.debug(3, "Transfering exception from tid=%i:\n%s"% (self.threadid, traceback.format_exception(*exc_info)))
760                m = Message(type=_ATMT_Command.EXCEPTION, exception=e, exc_info=exc_info)
761                self.cmdout.send(m)
762            self.debug(3, "Stopping control thread (tid=%i)"%self.threadid)
763            self.threadid = None
764
765    def _do_iter(self):
766        while True:
767            try:
768                self.debug(1, "## state=[%s]" % self.state.state)
769
770                # Entering a new state. First, call new state function
771                if self.state.state in self.breakpoints and self.state.state != self.breakpointed:
772                    self.breakpointed = self.state.state
773                    yield self.Breakpoint("breakpoint triggered on state %s" % self.state.state,
774                                          state = self.state.state)
775                self.breakpointed = None
776                state_output = self.state.run()
777                if self.state.error:
778                    raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output),
779                                          result=state_output, state=self.state.state)
780                if self.state.final:
781                    raise StopIteration(state_output)
782
783                if state_output is None:
784                    state_output = ()
785                elif not isinstance(state_output, list):
786                    state_output = state_output,
787
788                # Then check immediate conditions
789                for cond in self.conditions[self.state.state]:
790                    self._run_condition(cond, *state_output)
791
792                # If still there and no conditions left, we are stuck!
793                if ( len(self.recv_conditions[self.state.state]) == 0 and
794                     len(self.ioevents[self.state.state]) == 0 and
795                     len(self.timeout[self.state.state]) == 1 ):
796                    raise self.Stuck("stuck in [%s]" % self.state.state,
797                                     state=self.state.state, result=state_output)
798
799                # Finally listen and pay attention to timeouts
800                expirations = iter(self.timeout[self.state.state])
801                next_timeout,timeout_func = next(expirations)
802                t0 = time.time()
803
804                fds = [self.cmdin]
805                if len(self.recv_conditions[self.state.state]) > 0:
806                    fds.append(self.listen_sock)
807                for ioev in self.ioevents[self.state.state]:
808                    fds.append(self.ioin[ioev.atmt_ioname])
809                while True:
810                    t = time.time()-t0
811                    if next_timeout is not None:
812                        if next_timeout <= t:
813                            self._run_condition(timeout_func, *state_output)
814                            next_timeout,timeout_func = next(expirations)
815                    if next_timeout is None:
816                        remain = None
817                    else:
818                        remain = next_timeout-t
819
820                    self.debug(5, "Select on %r" % fds)
821                    r = select_objects(fds, remain)
822                    self.debug(5, "Selected %r" % r)
823                    for fd in r:
824                        self.debug(5, "Looking at %r" % fd)
825                        if fd == self.cmdin:
826                            yield self.CommandMessage("Received command message")
827                        elif fd == self.listen_sock:
828                            try:
829                                pkt = self.listen_sock.recv(MTU)
830                            except recv_error:
831                                pass
832                            else:
833                                if pkt is not None:
834                                    if self.master_filter(pkt):
835                                        self.debug(3, "RECVD: %s" % pkt.summary())
836                                        for rcvcond in self.recv_conditions[self.state.state]:
837                                            self._run_condition(rcvcond, pkt, *state_output)
838                                    else:
839                                        self.debug(4, "FILTR: %s" % pkt.summary())
840                        else:
841                            self.debug(3, "IOEVENT on %s" % fd.ioname)
842                            for ioevt in self.ioevents[self.state.state]:
843                                if ioevt.atmt_ioname == fd.ioname:
844                                    self._run_condition(ioevt, fd, *state_output)
845
846            except ATMT.NewStateRequested as state_req:
847                self.debug(2, "switching from [%s] to [%s]" % (self.state.state,state_req.state))
848                self.state = state_req
849                yield state_req
850
851    ## Public API
852    def add_interception_points(self, *ipts):
853        for ipt in ipts:
854            if hasattr(ipt,"atmt_state"):
855                ipt = ipt.atmt_state
856            self.interception_points.add(ipt)
857
858    def remove_interception_points(self, *ipts):
859        for ipt in ipts:
860            if hasattr(ipt,"atmt_state"):
861                ipt = ipt.atmt_state
862            self.interception_points.discard(ipt)
863
864    def add_breakpoints(self, *bps):
865        for bp in bps:
866            if hasattr(bp,"atmt_state"):
867                bp = bp.atmt_state
868            self.breakpoints.add(bp)
869
870    def remove_breakpoints(self, *bps):
871        for bp in bps:
872            if hasattr(bp,"atmt_state"):
873                bp = bp.atmt_state
874            self.breakpoints.discard(bp)
875
876    def start(self, *args, **kargs):
877        if not self.started.locked():
878            self._do_start(*args, **kargs)
879
880    def run(self, resume=None, wait=True):
881        if resume is None:
882            resume = Message(type = _ATMT_Command.RUN)
883        self.cmdin.send(resume)
884        if wait:
885            try:
886                c = self.cmdout.recv()
887            except KeyboardInterrupt:
888                self.cmdin.send(Message(type = _ATMT_Command.FREEZE))
889                return
890            if c.type == _ATMT_Command.END:
891                return c.result
892            elif c.type == _ATMT_Command.INTERCEPT:
893                raise self.InterceptionPoint("packet intercepted", state=c.state.state, packet=c.pkt)
894            elif c.type == _ATMT_Command.SINGLESTEP:
895                raise self.Singlestep("singlestep state=[%s]"%c.state.state, state=c.state.state)
896            elif c.type == _ATMT_Command.BREAKPOINT:
897                raise self.Breakpoint("breakpoint triggered on state [%s]"%c.state.state, state=c.state.state)
898            elif c.type == _ATMT_Command.EXCEPTION:
899                six.reraise(c.exc_info[0], c.exc_info[1], c.exc_info[2])
900
901    def runbg(self, resume=None, wait=False):
902        self.run(resume, wait)
903
904    def next(self):
905        return self.run(resume = Message(type=_ATMT_Command.NEXT))
906    __next__ = next
907
908    def stop(self):
909        self.cmdin.send(Message(type=_ATMT_Command.STOP))
910        with self.started:
911            # Flush command pipes
912            while True:
913                r = select_objects([self.cmdin, self.cmdout], 0)
914                if not r:
915                    break
916                for fd in r:
917                    fd.recv()
918
919    def restart(self, *args, **kargs):
920        self.stop()
921        self.start(*args, **kargs)
922
923    def accept_packet(self, pkt=None, wait=False):
924        rsm = Message()
925        if pkt is None:
926            rsm.type = _ATMT_Command.ACCEPT
927        else:
928            rsm.type = _ATMT_Command.REPLACE
929            rsm.pkt = pkt
930        return self.run(resume=rsm, wait=wait)
931
932    def reject_packet(self, wait=False):
933        rsm = Message(type = _ATMT_Command.REJECT)
934        return self.run(resume=rsm, wait=wait)
935
936
937
938