# Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import os
import select
import struct
import sys
import threading
import time
import traceback

import dpkt


class SimulatorError(Exception):
    """A Simulator generic error."""


class NullContext(object):
    """A context manager without any functionality."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return False  # raises the exception if passed.


class Simulator(object):
    """A TUN/TAP network interface simulator class.

    This class allows several implementations of different fake hosts to
    coexist on the same TUN/TAP interface. It will dispatch the same packet
    to each one of the registered hosts, providing some basic filtering to
    simplify these implementations.
    """

    def __init__(self, iface):
        """Initialize the instance.

        @param tuntap.TunTap iface: the interface over which this simulator
        runs. Should not be shared with other modules.
        """
        self._iface = iface
        # List of (matcher, callback) pairs registered with add_match().
        self._rules = []
        # _events holds a list of events that need to be fired for each
        # timestamp stored on the key. The event list is a list of callback
        # functions that will be called if the simulation reaches that
        # timestamp. This is used to fire time-based events.
        self._events = {}
        # Serialized packets waiting to be written to the interface.
        self._write_queue = []
        # A pipe used to wake up the run() method from a different thread
        # calling stop(). See the stop() method for details.
        self._pipe_rd, self._pipe_wr = os.pipe()
        self._running = False
        # Lock object used for _events if multithreading is required.
        self._lock = NullContext()

    def __del__(self):
        """Closes both ends of the wake-up pipe."""
        # The attributes are missing if __init__() failed before os.pipe().
        for attr in ("_pipe_rd", "_pipe_wr"):
            fd = getattr(self, attr, None)
            if fd is not None:
                os.close(fd)

    def add_match(self, rule, callback):
        """Add a new match rule to the outbound traffic.

        This function adds a new rule that will be matched against each
        packet that the host sends through the interface and will call a
        callback if it matches.

        The rule can be specified in the following ways:
          * A python function that takes a packet as a single argument and
            returns True when the packet matches.
          * A dictionary of key=value pairs that all of them need to be
            matched. A pair matches when the packet has the provided chain
            of attributes and its value is equal to the provided value. For
            example, this will match any DNS traffic sent to the host
            192.168.0.1:
              {"ip.dst": socket.inet_aton("192.168.0.1"),
               "ip.udp.dport": 53}

        @param rule: The rule description.
        @param callback: A callback function that receives the dpkt packet
        as the only argument.
        @raise SimulatorError: if |callback| is not callable or |rule| is
        neither a callable nor a dict.
        """
        if not callable(callback):
            raise SimulatorError("|callback| must be a callable object.")

        if callable(rule):
            matcher = rule
        elif isinstance(rule, dict):
            # Makes a copy of the dict, but not the contents.
            rule = dict(rule)
            matcher = lambda p: self._dict_rule(rule, p)
        else:
            # The 1-tuple keeps the formatting correct when |rule| is itself
            # a tuple.
            raise SimulatorError("Unknown rule format: %r" % (rule,))
        self._rules.append((matcher, callback))

    def add_timeout(self, timeout, callback):
        """Add a new callback function to be called after a timeout.

        This method schedules the given |callback| to be called after
        |timeout| seconds. The callback will be called at most once while
        the simulator is running (see the run() method). To have a
        repetitive event call add_timeout() again from the callback.

        @param timeout: The delay in seconds, counted from now.
        @param callback: A callback function that doesn't receive any
        argument.
        @raise SimulatorError: if |callback| is not callable.
        """
        if not callable(callback):
            raise SimulatorError("|callback| must be a callable object.")
        timestamp = time.time() + timeout
        with self._lock:
            self._events.setdefault(timestamp, []).append(callback)

    def remove_timeout(self, callback):
        """Removes every scheduled timeout call to the passed callback.

        When a callable object is passed to add_timeout() it is scheduled
        to be called once the timeout is reached. This method removes all
        the scheduled calls to that object.

        @param callback: The callable object passed to add_timeout().
        @return: Whether the callback was found and removed at least once.
        """
        removed = False
        # _events is also modified by add_timeout() and run(), possibly
        # from another thread, so it is only touched with the lock held.
        with self._lock:
            # Iterate over a snapshot of the keys since entries are deleted.
            for timestamp in list(self._events):
                pending = self._events[timestamp]
                kept = [cb for cb in pending if not cb == callback]
                if len(kept) == len(pending):
                    continue
                removed = True
                if kept:
                    self._events[timestamp] = kept
                else:
                    # Drop empty entries so run() doesn't wake up for them.
                    del self._events[timestamp]
        return removed

    def _dict_rule(self, rules, pkt):
        """Returns whether a given packet matches a set of rules.

        The matching rules passed in |rules| need to be a dict() as
        described on the add_match() method. The packet |pkt| is any dpkt
        packet.
        """
        for key, value in rules.items():
            p = pkt
            # Walk the dotted attribute chain, e.g. "ip.udp.dport".
            for member in key.split('.'):
                if not hasattr(p, member):
                    return False
                p = getattr(p, member)
            if p != value:
                return False
        return True

    def write(self, pkt):
        """Writes a packet to the network interface.

        The packet is only queued here; run() writes it out once the
        interface is ready for writing.

        @param pkt: The dpkt.Packet to be received on the network
        interface.
        """
        # Converts the dpkt packet to: flags, proto, buffer.
        header = struct.pack("!HH", 0, pkt.type)
        self._write_queue.append(header + bytes(pkt))

    def run(self, timeout=None, until=None):
        """Runs the Simulator.

        This method blocks the caller thread until the timeout is reached
        (if a timeout is passed), until stop() is called or until the
        function passed in until returns a True value (if a function is
        passed); whichever occurs first. stop() can be called from any
        other thread or from a callback called from this thread.

        @param timeout: The timeout in seconds. Can be a float value, or
        None for no timeout.
        @param until: A callable object called during the loop returning
        True when the loop should stop.
        @raise SimulatorError: if the interface is down.
        """
        if not self._iface.is_up():
            raise SimulatorError("Interface is down.")

        stop_callback = None
        if timeout is not None:
            # We use a newly created callable object to avoid removing
            # another scheduled call to self.stop.
            stop_callback = lambda: self.stop()
            self.add_timeout(timeout, stop_callback)

        self._running = True
        iface_fd = self._iface.fileno()
        try:
            # Check the until function.
            while not (until and until()):
                # The main purpose of this loop is to wait (block) until
                # the next event is required to be fired. There are four
                # kinds of events:
                #  * a packet is received.
                #  * a packet waiting to be sent can now be sent.
                #  * a time-based event needs to be fired.
                #  * the simulator was stopped from a different thread.
                # To achieve this we use select.select() to wait
                # simultaneously on all those event sources.
                wait_time = self._fire_due_events()

                # Compute the list of file descriptors that select.select()
                # needs to monitor to attend the required events. select()
                # will return when any of the following occurs:
                #  * rlist: it is possible to read from the interface or
                #           another thread wants to wake up the loop.
                #  * wlist: it is possible to write to the network, if
                #           there's a packet pending.
                #  * xlist: an error on the network fd occurred. Likely the
                #           TAP interface was closed.
                #  * timeout: The previously computed timeout was reached.
                rlist = iface_fd, self._pipe_rd
                wlist = (iface_fd,) if self._write_queue else ()
                xlist = iface_fd,

                rlist, wlist, xlist = select.select(
                        rlist, wlist, xlist, wait_time)

                if self._pipe_rd in rlist:
                    msg = os.read(self._pipe_rd, 1)
                    # stop() breaks the loop sending a '*'.
                    if b'*' in msg:
                        break
                    # Other messages are ignored.

                if xlist:
                    break

                if iface_fd in wlist:
                    self._iface.write(self._write_queue.pop(0))
                    # Attempt to send all the scheduled packets before
                    # reading more.
                    continue

                # Process the given packet:
                if iface_fd in rlist:
                    self._dispatch_packet(self._iface.read())
        finally:
            # Leave a consistent state even if a callback raised.
            if stop_callback:
                self.remove_timeout(stop_callback)
            self._running = False

    def _fire_due_events(self):
        """Fires the expired timeouts and computes the next wait time.

        @return: The number of seconds select() may block, at most 1.0.
        """
        wait_time = None
        cur_time = time.time()
        with self._lock:
            # Check events that should be fired.
            while self._events and min(self._events) <= cur_time:
                callbacks = self._events.pop(min(self._events))
                for callback in callbacks:
                    callback()
                cur_time = time.time()
            # Check if there is an event to attend. Here we know that
            # min(self._events) > cur_time because the previous while
            # finished.
            if self._events:
                wait_time = min(self._events) - cur_time  # in seconds

        # Poll the until() function at least once a second.
        if wait_time is None or wait_time > 1.0:
            wait_time = 1.0
        return wait_time

    def _dispatch_packet(self, raw):
        """Parses a raw buffer and calls the callbacks of matching rules.

        @param raw: The buffer as returned by the interface's read().
        """
        # The first 4 bytes are the (flags, proto) header; see write().
        frame = raw[4:]
        pkt = dpkt.ethernet.Ethernet(frame)
        for rule, callback in self._rules:
            if rule(pkt):
                # Parse again the packet to allow callbacks to modify it.
                callback(dpkt.ethernet.Ethernet(frame))

    def stop(self):
        """Stops the run() method if it is running."""
        # run() is blocked on select(); a '*' on the pipe breaks its loop.
        os.write(self._pipe_wr, b'*')


class SimulatorThread(threading.Thread, Simulator):
    """A threaded version of the Simulator.

    This class exposes a similar interface as the Simulator class with the
    difference that it runs on its own thread. This exposes an extra method
    start() that should be called instead of Simulator.run(). start() will
    make the process run continuously until stop() is called, after which
    the simulator can't be restarted.

    The methods used to add new matches can be called from any thread
    *before* the method start() is called. After that point, only the
    callbacks, running from this thread, are allowed to create new matches
    and timeouts.

    Example:
        simu = SimulatorThread(tap_interface)
        simu.add_match({"ip.tcp.dport": 80}, some_callback)
        simu.start()
        time.sleep(100)
        simu.stop()
        simu.join() # Optional
    """

    def __init__(self, iface, timeout=None):
        """Initialize the instance.

        @param iface: The interface passed to Simulator.__init__().
        @param timeout: The timeout in seconds passed to Simulator.run()
        when the thread runs, or None for no timeout.
        """
        threading.Thread.__init__(self)
        Simulator.__init__(self, iface)
        self._timeout = timeout
        # We allow the same thread to acquire the lock more than once. This
        # is useful if a callback wants to add itself.
        self._lock = threading.RLock()
        # Exception raised by the simulator loop and its formatted
        # traceback; both stay None unless run() fails.
        self.error = None
        self.traceback = None

    def run_on_simulator(self, callback):
        """Runs the given callback on the SimulatorThread thread.

        Before calling start() on the SimulatorThread, all the calls
        setting up the simulator are allowed, but once the thread is
        running, concurrency problems should be considered. This method
        runs the provided callback on the simulator.

        @param callback: A callback function without arguments.
        """
        self.add_timeout(0, callback)
        # Wake up the main loop with an ignored message.
        os.write(self._pipe_wr, b' ')

    def wait_for_condition(self, condition, timeout=None):
        """Blocks until the condition is met or timeout is exceeded.

        This method should be called from a different thread while the
        simulator thread is running as it blocks the calling thread's
        execution until a condition is met. The condition function is
        evaluated in a callback running on the simulator thread and thus
        can safely access objects owned by the simulator.

        @param condition: A function called on the simulator thread that
        returns a value indicating if the condition is met.
        @param timeout: The timeout in seconds. None for no timeout.
        @return: The value returned by condition the last time it was
        called. This means that in the event of a timeout, this function
        will return a value that evaluates to False since the condition
        wasn't met the last time it was checked.
        """
        # Lock and Condition used to wait until the passed condition is
        # met.
        lock_cond = threading.Lock()
        cond_var = threading.Condition(lock_cond)
        # We use a mutable object like the [] to pass the reference by
        # value to the simulator's callback and let it modify the contents.
        ret = [None]

        # Create the actual callback that will be running on the simulator
        # thread. It receives a reference to itself so it can reschedule
        # itself.
        callback = lambda: self._condition_poller(
                callback, ret, cond_var, condition)

        # Let the simulator keep calling our function, it will keep calling
        # itself until the condition is met (or we remove it).
        self.run_on_simulator(callback)

        # Condition variable waiting loop.
        cur_time = time.time()
        start_time = cur_time
        with cond_var:
            while not ret[0]:
                if timeout is None:
                    cond_var.wait()
                else:
                    cur_timeout = timeout - (cur_time - start_time)
                    if cur_timeout < 0:
                        break
                    cond_var.wait(cur_timeout)
                    cur_time = time.time()
        # Stop the polling in case we gave up before the condition was met.
        self.remove_timeout(callback)
        return ret[0]

    def _condition_poller(self, callback, ref_value, cond_var, func):
        """Callback function used to poll for a condition.

        This method keeps scheduling itself in the simulator until the
        passed condition evaluates to a True value. This effectively
        implements a polling mechanism. See wait_for_condition() for
        details.
        """
        with cond_var:
            ref_value[0] = func()
            if ref_value[0]:
                cond_var.notify()
            else:
                # Not met yet: check again in one second.
                self.add_timeout(1., callback)

    def run(self):
        """Runs the simulation on the thread, called by start().

        This method wraps the Simulator.run() to pass the timeout value
        passed during construction. An exception raised by the simulation
        is stored in |error| and its formatted traceback in |traceback|
        instead of being propagated.
        """
        try:
            Simulator.run(self, self._timeout)
        except Exception as e:
            self.error = e
            self.traceback = ''.join(
                    traceback.format_exception(*sys.exc_info()))