# Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import dpkt
import os
import select
import struct
import sys
import threading
import time
import traceback


class SimulatorError(Exception):
    """A Simulator generic error."""


class NullContext(object):
    """A context manager without any functionality.

    Used in place of a real lock when the Simulator runs on a single thread.
    """
    def __enter__(self):
        return self


    def __exit__(self, exc_type, exc_val, exc_tb):
        return False  # Re-raises the exception, if any was passed.


class Simulator(object):
    """A TUN/TAP network interface simulator class.

    This class allows several implementations of different fake hosts to
    coexist on the same TUN/TAP interface. It dispatches the same packet
    to each one of the registered hosts, providing some basic filtering
    to simplify these implementations.
    """

    def __init__(self, iface):
        """Initialize the instance.

        @param tuntap.TunTap iface: the interface over which this simulator
                runs. Should not be shared with other modules.
        """
        self._iface = iface
        # List of (rule, callback) pairs; rule is a predicate on a packet.
        self._rules = []
        # _events maps a timestamp (as returned by time.time()) to the list
        # of callbacks that need to be fired once the simulation reaches
        # that timestamp. This is used to fire time-based events.
        self._events = {}
        # Raw frames (packet information header + packet) pending to be
        # written to the interface, in FIFO order.
        self._write_queue = []
        # A pipe used to wake up the run() method from a different thread
        # calling stop(). See the stop() method for details.
        self._pipe_rd, self._pipe_wr = os.pipe()
        self._running = False
        # Lock object guarding _events if multithreading is required.
        self._lock = NullContext()


    def __del__(self):
        # The pipe may not exist if __init__ failed before creating it.
        for attr in ('_pipe_rd', '_pipe_wr'):
            fd = getattr(self, attr, None)
            if fd is not None:
                os.close(fd)


    def add_match(self, rule, callback):
        """Add a new match rule to the outbound traffic.

        This function adds a new rule that will be matched against each packet
        that the host sends through the interface and will call a callback if
        it matches. The rule can be specified in the following ways:
         * A python function that takes a packet as a single argument and
           returns True when the packet matches.
         * A dictionary of key=value pairs that all of them need to be matched.
           A pair matches when the packet has the provided chain of attributes
           and its value is equal to the provided value. For example, this will
           match any DNS traffic sent to the host 192.168.0.1:
           {"ip.dst": socket.inet_aton("192.168.0.1"),
            "ip.udp.dport": 53}

        @param rule: The rule description.
        @param callback: A callback function that receives the dpkt packet as
                the only argument.
        @raise SimulatorError: if |callback| is not callable or |rule| is
                neither a callable nor a dict.
        """
        if not callable(callback):
            raise SimulatorError("|callback| must be a callable object.")

        if callable(rule):
            self._rules.append((rule, callback))
        elif isinstance(rule, dict):
            # Copy the dict (but not its contents) so later changes made by
            # the caller to its own dict don't affect this rule.
            rule_copy = dict(rule)
            self._rules.append(
                    (lambda p: self._dict_rule(rule_copy, p), callback))
        else:
            # Wrap |rule| in a tuple so a tuple rule is formatted as a whole
            # instead of being taken as the format arguments.
            raise SimulatorError("Unknown rule format: %r" % (rule,))


    def add_timeout(self, timeout, callback):
        """Add a new callback function to be called after a timeout.

        This method schedules the given |callback| to be called after |timeout|
        seconds. The callback will be called at most once while the simulator
        is running (see the run() method). To have a repetitive event call again
        add_timeout() from the callback.

        @param timeout: The delay in seconds. Can be a float value.
        @param callback: A callback function that doesn't receive any argument.
        @raise SimulatorError: if |callback| is not callable.
        """
        if not callable(callback):
            raise SimulatorError("|callback| must be a callable object.")
        timestamp = time.time() + timeout
        with self._lock:
            self._events.setdefault(timestamp, []).append(callback)


    def remove_timeout(self, callback):
        """Removes every scheduled timeout call to the passed callback.

        When a callable object is passed to add_timeout() it is scheduled to be
        called once the timeout is reached. This method removes all the
        scheduled calls to that object.

        @param callback: The callable object passed to add_timeout().
        @return: Whether the callback was found and removed at least once.
        """
        removed = False
        # Hold the lock: this may be called from a thread other than the one
        # running the simulator loop (see SimulatorThread.wait_for_condition).
        with self._lock:
            # Iterate over a snapshot since empty entries are deleted below.
            for timestamp, ev_list in list(self._events.items()):
                while callback in ev_list:
                    ev_list.remove(callback)
                    removed = True
                # Drop empty entries so they don't wake up run() for nothing.
                if not ev_list:
                    del self._events[timestamp]
        return removed


    def _dict_rule(self, rules, pkt):
        """Returns whether a given packet matches a set of rules.

        The matching rules passed in |rules| need to be a dict() as described
        on the add_match() method. The packet |pkt| is any dpkt packet.
        """
        for key, value in rules.items():
            p = pkt
            for member in key.split('.'):
                if not hasattr(p, member):
                    return False
                p = getattr(p, member)
            if p != value:
                return False
        return True


    def write(self, pkt):
        """Writes a packet to the network interface.

        The packet is queued and sent by the run() loop once the interface is
        writable.

        @param pkt: The dpkt packet to be received on the network interface.
                Its |type| attribute is used as the protocol field (presumably
                a dpkt.ethernet.Ethernet -- confirm with callers).
        """
        # Prepends the packet information header: flags, proto.
        self._write_queue.append(struct.pack("!HH", 0, pkt.type) + bytes(pkt))


    def _fire_due_events(self):
        """Fires every time-based event whose timestamp has been reached.

        Callbacks run while holding self._lock.

        @return: The number of seconds until the next scheduled event, or None
                if no event is scheduled.
        """
        with self._lock:
            cur_time = time.time()
            while self._events:
                next_ts = min(self._events)
                if next_ts > cur_time:
                    return next_ts - cur_time
                for callback in self._events.pop(next_ts):
                    callback()
                # Callbacks may take a while; refresh the current time.
                cur_time = time.time()
        return None


    def _dispatch_packet(self, raw):
        """Passes a frame read from the interface to every matching rule.

        @param raw: The data read from the interface: a 4-byte packet
                information header (flags, proto) followed by the packet.
        """
        # The header fields are not used, but unpacking them checks that the
        # header is present.
        _flags, _proto = struct.unpack("!HH", raw[:4])
        payload = raw[4:]
        pkt = dpkt.ethernet.Ethernet(payload)
        for rule, callback in self._rules:
            if rule(pkt):
                # Parse the packet again to allow callbacks to modify it.
                callback(dpkt.ethernet.Ethernet(payload))


    def run(self, timeout=None, until=None):
        """Runs the Simulator.

        This method blocks the caller thread until the timeout is reached (if
        a timeout is passed), until stop() is called, until the function
        passed in |until| returns a True value (if a function is passed) or
        until select() reports an error condition on the interface; whichever
        occurs first. stop() can be called from any other thread or from a
        callback called from this thread. Exceptions raised by callbacks
        propagate to the caller.

        @param timeout: The timeout in seconds. Can be a float value, or None
                for no timeout.
        @param until: A callable object called during the loop returning True
                when the loop should stop. It is polled at least once a second.
        @raise SimulatorError: if the interface is down.
        """
        if not self._iface.is_up():
            raise SimulatorError("Interface is down.")

        stop_callback = None
        if timeout is not None:
            # Use a newly created callable object: bound methods compare
            # equal, so removing self.stop itself could remove a call to
            # self.stop scheduled by someone else.
            stop_callback = lambda: self.stop()
            self.add_timeout(timeout, stop_callback)

        self._running = True
        iface_fd = self._iface.fileno()
        try:
            while not (until and until()):
                # The main purpose of this loop is to wait (block) until the
                # next event is required to be fired. There are four kinds of
                # events:
                #  * a packet is received.
                #  * a packet waiting to be sent can now be sent.
                #  * a time-based event needs to be fired.
                #  * the simulator was stopped from a different thread.
                # To achieve this we use select.select() to wait simultaneously
                # on all those event sources.
                select_timeout = self._fire_due_events()

                # Poll the until() function at least once a second.
                if select_timeout is None or select_timeout > 1.0:
                    select_timeout = 1.0

                # select() returns when any of the following occurs:
                #  * rlist: it is possible to read from the interface or
                #    another thread wants to wake up the simulator loop.
                #  * wlist: it is possible to write to the network, if there
                #    is a packet pending.
                #  * xlist: an error on the network fd occurred. Likely the
                #    TAP interface was closed.
                #  * timeout: the previously computed timeout was reached.
                rlist = (iface_fd, self._pipe_rd)
                wlist = (iface_fd,) if self._write_queue else ()
                xlist = (iface_fd,)

                rlist, wlist, xlist = select.select(
                        rlist, wlist, xlist, select_timeout)

                if self._pipe_rd in rlist:
                    msg = os.read(self._pipe_rd, 1)
                    # stop() breaks the loop sending a '*'. Other messages
                    # only wake up the loop and are ignored.
                    if b'*' in msg:
                        break

                if xlist:
                    break

                if iface_fd in wlist:
                    self._iface.write(self._write_queue.pop(0))
                    # Attempt to send all the scheduled packets before reading
                    # more.
                    continue

                if iface_fd in rlist:
                    self._dispatch_packet(self._iface.read())
        finally:
            # Always clean up, even if a callback raised, so a stale stop
            # request does not cut short a later run().
            if stop_callback:
                self.remove_timeout(stop_callback)
            self._running = False


    def stop(self):
        """Stops the run() method if it is running.

        Writes a '*' to the internal pipe, which wakes up run()'s select() call
        and makes the loop exit. If run() is not running, the request stays in
        the pipe and is seen by the next run() loop.
        """
        os.write(self._pipe_wr, b'*')


class SimulatorThread(threading.Thread, Simulator):
    """A threaded version of the Simulator.

    This class exposes a similar interface to the Simulator class, with the
    difference that it runs on its own thread. It exposes an extra method,
    start(), that should be called instead of Simulator.run(). start() makes
    the simulator run continuously until stop() is called, after which the
    simulator can't be restarted.

    The methods used to add new matches can be called from any thread *before*
    the method start() is called. After that point, only the callbacks, running
    from this thread, are allowed to create new matches and timeouts.

    Example:
        simu = SimulatorThread(tap_interface)
        simu.add_match({"ip.tcp.dport": 80}, some_callback)
        simu.start()
        time.sleep(100)
        simu.stop()
        simu.join() # Optional
    """

    def __init__(self, iface, timeout=None):
        threading.Thread.__init__(self)
        Simulator.__init__(self, iface)
        self._timeout = timeout
        # We allow the same thread to acquire the lock more than once. This is
        # needed when a callback, fired while the loop holds the lock, calls
        # add_timeout() or remove_timeout() (e.g. to add itself again).
        self._lock = threading.RLock()
        # Exception raised inside run(), if any, and its formatted traceback.
        self.error = None
        self.traceback = None


    def run_on_simulator(self, callback):
        """Runs the given callback on the SimulatorThread thread.

        Before calling start() on the SimulatorThread, all the calls setting up
        the simulator are allowed, but once the thread is running, concurrency
        problems should be considered. This method runs the provided callback
        on the simulator thread.

        @param callback: A callback function without arguments.
        """
        self.add_timeout(0, callback)
        # Wake up the main loop with an ignored message.
        os.write(self._pipe_wr, b' ')


    def wait_for_condition(self, condition, timeout=None):
        """Blocks until the condition is met or timeout is exceeded.

        This method should be called from a different thread while the simulator
        thread is running as it blocks the calling thread's execution until a
        condition is met. The condition function is evaluated in a callback
        running on the simulator thread and thus can safely access objects owned
        by the simulator.

        @param condition: A function called on the simulator thread that returns
                a value indicating if the condition is met.
        @param timeout: The timeout in seconds. None for no timeout.
        @return: The value returned by condition the last time it was called.
                This means that in the event of a timeout, this function will
                return a value that evaluates to False since the condition
                wasn't met the last time it was checked.
        """
        # Lock and Condition used to wait until the passed condition is met.
        lock_cond = threading.Lock()
        cond_var = threading.Condition(lock_cond)
        # A mutable container lets the simulator's callback hand back the
        # condition's result.
        ret = [None]

        # The callback running on the simulator thread receives a reference
        # to itself so it can reschedule itself.
        callback = lambda: self._condition_poller(
                callback, ret, cond_var, condition)

        # Let the simulator keep calling our function; it reschedules itself
        # until the condition is met (or we remove it).
        self.run_on_simulator(callback)

        # Condition variable waiting loop.
        start_time = time.time()
        with cond_var:
            while not ret[0]:
                if timeout is None:
                    cond_var.wait()
                else:
                    remaining = timeout - (time.time() - start_time)
                    if remaining <= 0:
                        break
                    cond_var.wait(remaining)

        # Remove the poller only after releasing cond_var: the simulator thread
        # runs the poller while holding self._lock and then takes cond_var, so
        # taking self._lock here while holding cond_var could deadlock.
        # Once this returns, the poller can no longer run, so ret[0] is final.
        self.remove_timeout(callback)
        return ret[0]


    def _condition_poller(self, callback, ref_value, cond_var, func):
        """Callback function used to poll for a condition.

        This method keeps scheduling itself (once a second) in the simulator
        until the passed condition evaluates to a True value. This effectively
        implements a polling mechanism. See wait_for_condition() for details.
        """
        with cond_var:
            ref_value[0] = func()
            if ref_value[0]:
                cond_var.notify()
            else:
                self.add_timeout(1., callback)


    def run(self):
        """Runs the simulation on the thread, called by start().

        This method wraps Simulator.run() to pass the timeout value given
        during construction. Any exception raised by the simulation is stored
        in self.error, with its formatted traceback in self.traceback.
        """
        try:
            Simulator.run(self, self._timeout)
        except Exception as e:
            self.error = e
            exc_type, exc_value, exc_traceback = sys.exc_info()
            self.traceback = ''.join(traceback.format_exception(
                    exc_type, exc_value, exc_traceback))