# Copyright 2017 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A utility for "twisting" packets on a tun/tap interface.

TunTwister and TapTwister echo packets on a tun/tap while swapping the source
and destination at the ethernet and IP layers. This allows sockets to
effectively loop back packets through the full networking stack, avoiding any
shortcuts the kernel may take for actual IP loopback. Additionally, users can
inspect each packet to assert testing invariants.
"""

import errno
import os
import select
import threading

from scapy import all as scapy


class TunTwister(object):
  """TunTwister transports traffic travelling twixt two terminals.

  TunTwister is a context manager that will read packets from a tun file
  descriptor, swap the source and dest of the IP header, and write them back.
  To use this class, tests also need to set up routing so that packets will be
  routed to the tun interface.

  Two sockets can communicate with each other through a TunTwister as if they
  were each connecting to a remote endpoint. Both sockets will have the
  perspective that the address of the other is a remote address.

  Packet inspection can be done with a validator function. This can be any
  function that takes a scapy packet object as its only argument. Exceptions
  raised by your validator function will be re-raised on the main thread to fail
  your tests.

  NOTE: Exceptions raised by a validator function will supersede exceptions
  raised in the context.

  EXAMPLE:
    def testFeatureFoo(self):
      my_tun = MakeTunInterface()
      # Set up routing so packets go to my_tun.

      def ValidatePortNumber(packet):
        self.assertEqual(8080, packet.getlayer(scapy.UDP).sport)
        self.assertEqual(8080, packet.getlayer(scapy.UDP).dport)

      with TunTwister(fd=my_tun, validator=ValidatePortNumber):
        sock = socket(AF_INET, SOCK_DGRAM, 0)
        sock.bind(("0.0.0.0", 8080))
        sock.settimeout(1.0)
        sock.sendto(b"hello", ("1.2.3.4", 8080))
        data, addr = sock.recvfrom(1024)
        self.assertEqual(b"hello", data)
        self.assertEqual(("1.2.3.4", 8080), addr)
  """

  # Hopefully larger than any packet.
  _READ_BUF_SIZE = 2048
  # Timeout for the main select() loop and for joining the worker thread.
  _POLL_TIMEOUT_SEC = 2.0
  # Timeout used while draining leftover packets after the exit signal.
  _POLL_FAST_TIMEOUT_MS = 100

  def __init__(self, fd=None, validator=None):
    """Construct a TunTwister.

    The TunTwister will listen on the given TUN fd.
    The validator is called for each packet *before* twisting. The packet is
    passed in as a scapy packet object, and is the only argument passed to the
    validator.

    Args:
      fd: File descriptor of a TUN interface.
      validator: Function taking one scapy packet object argument.
    """
    self._fd = fd
    # Use a pipe to signal the thread to exit.
    self._signal_read, self._signal_write = os.pipe()
    self._thread = threading.Thread(target=self._RunLoop, name="TunTwister")
    self._validator = validator
    # Holds the first exception raised on the worker thread, if any.
    self._error = None

  def __enter__(self):
    """Start the twisting thread and return this twister."""
    self._thread.start()
    return self

  def __exit__(self, *args):
    """Stop the twisting thread and surface any error it recorded.

    Raises:
      RuntimeError: The worker thread did not exit within the join timeout.
      Exception: Whatever exception the worker thread recorded (for example one
        raised by the validator).
    """
    try:
      try:
        # Signal thread exit. Bytes literal so this works on Python 2 and 3.
        os.write(self._signal_write, b"bye")
      finally:
        # Closing the write end also makes the read end readable (EOF), so the
        # thread is still woken up if the write above failed.
        os.close(self._signal_write)
      self._thread.join(TunTwister._POLL_TIMEOUT_SEC)
    finally:
      # Always release the read end, even if signalling or joining raised.
      os.close(self._signal_read)
    if self._thread.is_alive():
      raise RuntimeError("Timed out waiting for thread exit")
    # Re-raise any error thrown from our thread.
    if isinstance(self._error, Exception):
      raise self._error  # pylint: disable=raising-bad-type

  def _RunLoop(self):
    """Twist packets until exit signal.

    Any exception raised while processing is stored in self._error so that
    __exit__ can re-raise it on the caller's thread; the loop then ends.
    """
    try:
      while True:
        read_fds, _, _ = select.select([self._fd, self._signal_read], [], [],
                                       TunTwister._POLL_TIMEOUT_SEC)
        if self._signal_read in read_fds:
          # Drain whatever is still queued on the tun before exiting.
          self._Flush()
          return
        if self._fd in read_fds:
          self._ProcessPacket()
    except Exception as e:  # pylint: disable=broad-except
      self._error = e

  def _Flush(self):
    """Ensure no packets are left in the buffer."""
    p = select.poll()
    p.register(self._fd, select.POLLIN)
    # poll() returns an empty list on timeout, which ends the loop.
    while p.poll(TunTwister._POLL_FAST_TIMEOUT_MS):
      self._ProcessPacket()

  def _ProcessPacket(self):
    """Read, twist, and write one packet on the tun/tap."""
    try:
      bytes_in = os.read(self._fd, TunTwister._READ_BUF_SIZE)
    except OSError as e:
      # On a non-blocking fd the read can report "try again" when no packet is
      # actually available; that is not an error, there is just nothing to do.
      if e.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
        return
      raise
    packet = self.DecodePacket(bytes_in)
    # The user may wish to filter certain packets, such as Ethernet multicast
    # packets.
    if self._DropPacket(packet):
      return

    if self._validator:
      self._validator(packet)
    packet = self.TwistPacket(packet)
    os.write(self._fd, packet.build())

  def _DropPacket(self, packet):
    """Determine whether to drop the provided packet by inspection.

    The base implementation never drops; subclasses may override.

    Args:
      packet: The decoded scapy packet.

    Returns:
      True if the packet should be discarded without validating or twisting.
    """
    return False

  @classmethod
  def DecodePacket(cls, bytes_in):
    """Decode a byte array into a scapy object.

    Args:
      bytes_in: Raw bytes read from the tun fd.

    Returns:
      A scapy IP or IPv6 packet.

    Raises:
      ValueError: bytes_in is empty or is not an IPv4/IPv6 packet.
    """
    return cls._DecodeIpPacket(bytes_in)

  @classmethod
  def TwistPacket(cls, packet):
    """Swap the src and dst in the IP header.

    Args:
      packet: A scapy IP or IPv6 packet; its src/dst are modified in place.

    Returns:
      A new packet of the same type, re-parsed from the swapped packet's bytes.

    Raises:
      TypeError: packet is not exactly a scapy IP or IPv6 object.
    """
    ip_type = type(packet)
    if ip_type not in (scapy.IP, scapy.IPv6):
      raise TypeError("Expected an IPv4 or IPv6 packet.")
    packet.src, packet.dst = packet.dst, packet.src
    packet = ip_type(packet.build())  # Fix the IP checksum.
    return packet

  @staticmethod
  def _DecodeIpPacket(packet_bytes):
    """Decode 'packet_bytes' as an IPv4 or IPv6 scapy object.

    Args:
      packet_bytes: Raw packet bytes; the high nibble of the first byte selects
        the IP version.

    Returns:
      A scapy IP (version 4) or IPv6 (version 6) packet.

    Raises:
      ValueError: packet_bytes is empty or has an unsupported IP version.
    """
    # bytearray indexing yields an int on both Python 2 and Python 3, unlike
    # ord(packet_bytes[0]) which fails on Python 3 bytes.
    first_byte = bytearray(packet_bytes[:1])
    if not first_byte:
      raise ValueError("packet_bytes is empty")
    ip_ver = (first_byte[0] & 0xF0) >> 4
    if ip_ver == 4:
      return scapy.IP(packet_bytes)
    elif ip_ver == 6:
      return scapy.IPv6(packet_bytes)
    else:
      raise ValueError("packet_bytes is not a valid IPv4 or IPv6 packet")


class TapTwister(TunTwister):
  """Test util for tap interfaces.

  TapTwister works just like TunTwister, except it operates on tap interfaces
  instead of tuns. Ethernet headers will have their sources and destinations
  swapped in addition to IP headers.
  """

  @staticmethod
  def _IsMulticastPacket(eth_pkt):
    """Return True if the Ethernet destination has the group (multicast) bit.

    Args:
      eth_pkt: A scapy Ether packet whose dst is a colon-separated hex string.
    """
    # The least significant bit of the first address octet is the group bit.
    return bool(int(eth_pkt.dst.split(":")[0], 16) & 0x1)

  def __init__(self, fd=None, validator=None, drop_multicast=True):
    """Construct a TapTwister.

    TapTwister works just like TunTwister, but handles both ethernet and IP
    headers.

    Args:
      fd: File descriptor of a TAP interface.
      validator: Function taking one scapy packet object argument.
      drop_multicast: Drop Ethernet multicast packets
    """
    super(TapTwister, self).__init__(fd=fd, validator=validator)
    self._drop_multicast = drop_multicast

  def _DropPacket(self, packet):
    """Drop Ethernet multicast packets when drop_multicast was requested."""
    return self._drop_multicast and self._IsMulticastPacket(packet)

  @classmethod
  def DecodePacket(cls, bytes_in):
    """Decode raw bytes read from the tap fd into a scapy Ether packet."""
    return scapy.Ether(bytes_in)

  @classmethod
  def TwistPacket(cls, packet):
    """Swap the src and dst in the ethernet and IP headers.

    Args:
      packet: A scapy Ether packet whose payload is an IP or IPv6 packet.

    Returns:
      The same Ether packet with swapped addresses and a twisted IP payload.

    Raises:
      TypeError: The Ethernet payload is not a scapy IP or IPv6 object.
    """
    packet.src, packet.dst = packet.dst, packet.src
    ip_layer = packet.payload
    twisted_ip_layer = super(TapTwister, cls).TwistPacket(ip_layer)
    packet.payload = twisted_ip_layer
    return packet