• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A utility for "twisting" packets on a tun/tap interface.

TunTwister and TapTwister echo packets on a tun/tap while swapping the source
and destination at the ethernet and IP layers. This allows sockets to
effectively loop back packets through the full networking stack, avoiding any
shortcuts the kernel may take for actual IP loopback. Additionally, users can
inspect each packet to assert testing invariants.
"""
22
import errno
import os
import select
import threading

from scapy import all as scapy
27
28
class TunTwister(object):
  """TunTwister transports traffic travelling twixt two terminals.

  TunTwister is a context manager that will read packets from a tun file
  descriptor, swap the source and dest of the IP header, and write them back.
  To use this class, tests also need to set up routing so that packets will be
  routed to the tun interface.

  Two sockets can communicate with each other through a TunTwister as if they
  were each connecting to a remote endpoint. Both sockets will have the
  perspective that the address of the other is a remote address.

  Packet inspection can be done with a validator function. This can be any
  function that takes a scapy packet object as its only argument. Exceptions
  raised on the twisting thread (including those raised by your validator
  function) will be re-raised on the main thread to fail your tests.

  NOTE: Exceptions raised by a validator function will supersede exceptions
  raised in the context.

  NOTE: The worker thread and signal pipe are created in the constructor and
  torn down on exit, so each instance can be entered only once.

  EXAMPLE:
    def testFeatureFoo(self):
      my_tun = MakeTunInterface()
      # Set up routing so packets go to my_tun.

      def ValidatePortNumber(packet):
        self.assertEqual(8080, packet.getlayer(scapy.UDP).sport)
        self.assertEqual(8080, packet.getlayer(scapy.UDP).dport)

      with TunTwister(fd=my_tun, validator=ValidatePortNumber):
        sock = socket(AF_INET, SOCK_DGRAM, 0)
        sock.bind(("0.0.0.0", 8080))
        sock.settimeout(1.0)
        sock.sendto(b"hello", ("1.2.3.4", 8080))
        data, addr = sock.recvfrom(1024)
        self.assertEqual(b"hello", data)
        self.assertEqual(("1.2.3.4", 8080), addr)
  """

  # Hopefully larger than any packet.
  _READ_BUF_SIZE = 2048
  # Timeout for each select() in the worker loop, and for joining the worker
  # thread on exit.
  _POLL_TIMEOUT_SEC = 2.0
  # How long _Flush waits for another packet before treating the fd as drained.
  _POLL_FAST_TIMEOUT_MS = 100

  def __init__(self, fd=None, validator=None):
    """Construct a TunTwister.

    The TunTwister will listen on the given TUN fd.
    The validator is called for each packet *before* twisting. The packet is
    passed in as a scapy packet object, and is the only argument passed to the
    validator.

    Args:
      fd: File descriptor of a TUN interface.
      validator: Function taking one scapy packet object argument.
    """
    self._fd = fd
    # Use a pipe to signal the thread to exit.
    self._signal_read, self._signal_write = os.pipe()
    self._thread = threading.Thread(target=self._RunLoop, name="TunTwister")
    self._validator = validator
    # First exception caught on the worker thread; re-raised by __exit__.
    self._error = None

  def __enter__(self):
    """Start the twisting thread.

    Returns:
      This TunTwister, so it can be bound with `with ... as twister:`.
    """
    self._thread.start()
    return self

  def __exit__(self, *args):
    """Stop the twisting thread and surface any error it hit.

    Raises:
      RuntimeError: if the thread has not exited within _POLL_TIMEOUT_SEC.
      Exception: any exception caught on the twisting thread is re-raised.
    """
    # Signal thread exit. os.write() requires bytes on Python 3.
    os.write(self._signal_write, b"bye")
    os.close(self._signal_write)
    self._thread.join(TunTwister._POLL_TIMEOUT_SEC)
    # Thread.isAlive() no longer exists on Python 3.9+; is_alive() is portable.
    if self._thread.is_alive():
      # The thread may still select() on the read end. Leave it open rather
      # than close an fd that is in use and whose number could be reused.
      raise RuntimeError("Timed out waiting for thread exit")
    os.close(self._signal_read)
    # Re-raise any error thrown from our thread.
    if isinstance(self._error, Exception):
      raise self._error  # pylint: disable=raising-bad-type

  def _RunLoop(self):
    """Twist packets until the exit signal arrives.

    Any exception is stored in self._error (ending the loop) so that
    __exit__ can re-raise it on the caller's thread.
    """
    try:
      while True:
        read_fds, _, _ = select.select([self._fd, self._signal_read], [], [],
                                       TunTwister._POLL_TIMEOUT_SEC)
        if self._signal_read in read_fds:
          self._Flush()
          return
        if self._fd in read_fds:
          self._ProcessPacket()
    except Exception as e:  # pylint: disable=broad-except
      self._error = e

  def _Flush(self):
    """Ensure no packets are left in the buffer."""
    p = select.poll()
    p.register(self._fd, select.POLLIN)
    # Keep processing until no packet shows up within the fast timeout.
    while p.poll(TunTwister._POLL_FAST_TIMEOUT_MS):
      self._ProcessPacket()

  def _ProcessPacket(self):
    """Read, twist, and write one packet on the tun/tap."""
    try:
      bytes_in = os.read(self._fd, TunTwister._READ_BUF_SIZE)
    except OSError as e:
      # Readiness reported by select/poll is not a guarantee. On a
      # non-blocking fd, "nothing to read" is not an error.
      if e.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
        return
      raise
    packet = self.DecodePacket(bytes_in)
    # the user may wish to filter certain packets, such as
    # Ethernet multicast packets
    if self._DropPacket(packet):
      return

    if self._validator:
      self._validator(packet)
    packet = self.TwistPacket(packet)
    os.write(self._fd, packet.build())

  def _DropPacket(self, packet):
    """Determine whether to drop the provided packet by inspection.

    The base class never drops; subclasses override this.
    """
    return False

  @classmethod
  def DecodePacket(cls, bytes_in):
    """Decode a byte array into a scapy object."""
    return cls._DecodeIpPacket(bytes_in)

  @classmethod
  def TwistPacket(cls, packet):
    """Swap the src and dst in the IP header.

    Raises:
      TypeError: if packet is not exactly a scapy IP or IPv6 object.
    """
    ip_type = type(packet)
    if ip_type not in (scapy.IP, scapy.IPv6):
      raise TypeError("Expected an IPv4 or IPv6 packet.")
    packet.src, packet.dst = packet.dst, packet.src
    packet = ip_type(packet.build())  # Fix the IP checksum.
    return packet

  @staticmethod
  def _DecodeIpPacket(packet_bytes):
    """Decode 'packet_bytes' as an IPv4 or IPv6 scapy object.

    Raises:
      ValueError: if packet_bytes is empty or its version nibble is not 4 or 6.
    """
    if not packet_bytes:
      raise ValueError("packet_bytes is not a valid IPv4 or IPv6 packet")
    # Indexing a bytearray yields an int on both Python 2 and Python 3,
    # unlike ord(bytes[0]), which fails on Python 3.
    ip_ver = (bytearray(packet_bytes[:1])[0] & 0xF0) >> 4
    if ip_ver == 4:
      return scapy.IP(packet_bytes)
    elif ip_ver == 6:
      return scapy.IPv6(packet_bytes)
    else:
      raise ValueError("packet_bytes is not a valid IPv4 or IPv6 packet")
172
173
class TapTwister(TunTwister):
  """Test util for tap interfaces.

  Behaves like TunTwister, but on a tap interface: each packet is a full
  ethernet frame, and the ethernet source and destination are swapped along
  with the IP source and destination.
  """

  def __init__(self, fd=None, validator=None, drop_multicast=True):
    """Construct a TapTwister.

    Args:
      fd: File descriptor of a TAP interface.
      validator: Function taking one scapy packet object argument; it receives
        the decoded ethernet frame.
      drop_multicast: If true, frames sent to an ethernet multicast
        destination are discarded before validation and are not echoed.
    """
    super(TapTwister, self).__init__(fd=fd, validator=validator)
    self._drop_multicast = drop_multicast

  @staticmethod
  def _IsMulticastPacket(eth_pkt):
    """Return a truthy int if the frame's destination MAC is multicast.

    Checks the low-order bit of the first octet of the destination address.
    """
    leading_octet = int(eth_pkt.dst.split(":")[0], 16)
    return leading_octet & 0x1

  def _DropPacket(self, packet):
    """Drop multicast frames when drop_multicast was requested."""
    return self._drop_multicast and self._IsMulticastPacket(packet)

  @classmethod
  def DecodePacket(cls, bytes_in):
    """Decode raw bytes read from the tap as a scapy Ether frame."""
    return scapy.Ether(bytes_in)

  @classmethod
  def TwistPacket(cls, packet):
    """Swap the src and dst in the ethernet and IP headers.

    The ethernet frame is modified in place and also returned.
    """
    # Ethernet addresses first...
    packet.src, packet.dst = packet.dst, packet.src
    # ...then let TunTwister twist (and re-checksum) the carried IP layer.
    packet.payload = super(TapTwister, cls).TwistPacket(packet.payload)
    return packet
215