• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python
2#
3# Copyright 2016 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
import ctypes
import ctypes.util
import errno
import os
import socket
import struct
import subprocess
import tempfile
import unittest

from bpf import *  # pylint: disable=wildcard-import
import csocket
import net_test
import sock_diag
30
# Handle to the C library, opened with use_errno=True. find_library() lives in
# the ctypes.util submodule, which a plain "import ctypes" does not load; the
# import block therefore imports ctypes.util explicitly instead of relying on
# some other module having imported it first.
libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
# Feature gates derived from the version of the running kernel.
HAVE_EBPF_ACCOUNTING = net_test.LINUX_VERSION >= (4, 9, 0)
HAVE_EBPF_SOCKET = net_test.LINUX_VERSION >= (4, 14, 0)
# Arguments the tests pass to CreateMap(): key size, value size (presumably in
# bytes -- confirm against bpf.CreateMap) and maximum number of entries.
KEY_SIZE = 8
VALUE_SIZE = 4
TOTAL_ENTRIES = 20
# Uid (and gid) used by tests that run code as a different user.
TEST_UID = 54321
TEST_GID = 12345
# Offset, relative to the address held in BPF_REG_10, of the stack slot where
# the map key is stored.
key_offset = -8
# Offset, relative to the address held in BPF_REG_10, of the stack slot where
# the map value is stored.
value_offset = -16
43
# Debug usage only.
def PrintMapInfo(map_fd):
  """Prints the <key, value> pairs reachable by iterating over a map.

  Iteration starts from an arbitrary key that is assumed not to be in the map
  and follows GetNextKey() until a call fails, at which point "no value" is
  printed and the function returns.

  Args:
    map_fd: file descriptor of the eBPF map to dump.
  """
  # A random key that the map does not contain.
  key = 10086
  while True:
    try:
      next_key = GetNextKey(map_fd, key).value
      value = LookupMap(map_fd, next_key)
    except Exception:  # pylint: disable=broad-except
      # Any failure ends the dump; normally this is the error that marks the
      # end of the map. Unlike a bare "except:", KeyboardInterrupt and
      # SystemExit still propagate.
      print("no value")
      break
    # Single-argument print(...) behaves the same as the print statement.
    print(repr(next_key) + " : " + repr(value.value))
    key = next_key
57
58
# A dummy loopback function that causes a socket to send traffic to itself.
def SocketUDPLoopBack(packet_count, version, prog_fd):
  """Sends UDP packets from a loopback socket to itself and reads them back.

  Args:
    packet_count: number of packets to send and receive.
    version: IP version of the socket, 4 or 6.
    prog_fd: fd of a BPF program to attach to the socket before any traffic
      is sent, or None to attach nothing.

  Returns:
    The socket that was used, still open. The caller owns it and should close
    it.

  Raises:
    Whatever the attach, bind, send or receive calls raise (for example when a
    filter drops the packet and the non-blocking receive finds nothing). The
    socket is closed before the exception propagates, so a failing call does
    not leak a file descriptor.
  """
  family = {4: socket.AF_INET, 6: socket.AF_INET6}[version]
  sock = socket.socket(family, socket.SOCK_DGRAM, 0)
  try:
    if prog_fd is not None:
      BpfProgAttachSocket(sock.fileno(), prog_fd)
    net_test.SetNonBlocking(sock)
    addr = {4: "127.0.0.1", 6: "::1"}[version]
    # Bind to port 0 and then read back the address that was actually
    # assigned, so the socket can send to itself.
    sock.bind((addr, 0))
    addr = sock.getsockname()
    sockaddr = csocket.Sockaddr(addr)
    for _ in xrange(packet_count):
      sock.sendto("foo", addr)
      data, retaddr = csocket.Recvfrom(sock, 4096, 0)
      assert "foo" == data
      assert sockaddr == retaddr
  except BaseException:
    # Nobody else holds a reference to the socket yet; release it here.
    sock.close()
    raise
  return sock
76
77
# The main code block for eBPF packet counting program. It expects the map key
# to be on the stack at key_offset already (INS_BPF_PARAM_STORE puts the value
# of BPF_REG_0 there) and uses it to look up the bpf map. If the element does
# not exist in the map yet, the program will update the map with a new
# <key, 1> pair. Otherwise it will jump to the next code block to handle it.
# REG0: register storing the return value from a helper function and the final
# return value of the eBPF program.
# REG1 - REG5: temporary registers used for storing values and loading
# parameters into eBPF helper functions. After calling a helper function, the
# values of these registers will be reset.
# REG6 - REG9: registers storing values that will not be cleared when calling
# eBPF helper functions.
# REG10: A stack storing values that need to be accessed by address. A program
# can retrieve the address of a value by specifying the position of the value
# in the stack.
def BpfFuncCountPacketInit(map_fd):
  """Builds the "look up the key, create the entry if missing" instructions.

  Args:
    map_fd: file descriptor of the map; it is embedded in the program with
      BpfLoadMapFd.

  Returns:
    A list of eBPF instructions. Callers in this file append an accept block,
    INS_PACK_COUNT_UPDATE and another accept block after it.
  """
  key_pos = BPF_REG_7
  return [
      # Point BPF_REG_7 at the stack slot (REG10 + key_offset) holding the key.
      BpfMov64Reg(key_pos, BPF_REG_10),
      BpfAlu64Imm(BPF_ADD, key_pos, key_offset),
      # Load map fd and look up the key in the map
      BpfLoadMapFd(map_fd, BPF_REG_1),
      BpfMov64Reg(BPF_REG_2, key_pos),
      BpfFuncCall(BPF_FUNC_map_lookup_elem),
      # If the lookup result in BPF_REG_0 is not 0 the map element already
      # exists: jump out of this code block and let the next part handle it.
      # BPF_JNE is the jump condition meant here; the original spelled it with
      # the ALU constant BPF_AND. The offset of 10 has to stay in sync with
      # the instructions that follow, including the accept block that callers
      # append before INS_PACK_COUNT_UPDATE.
      BpfJumpImm(BPF_JNE, BPF_REG_0, 0, 10),
      BpfLoadMapFd(map_fd, BPF_REG_1),
      BpfMov64Reg(BPF_REG_2, key_pos),
      # Initialize a new <key, value> pair with value equal to 1 and add it
      # to the map.
      BpfStMem(BPF_W, BPF_REG_10, value_offset, 1),
      BpfMov64Reg(BPF_REG_3, BPF_REG_10),
      BpfAlu64Imm(BPF_ADD, BPF_REG_3, value_offset),
      BpfMov64Imm(BPF_REG_4, 0),
      BpfFuncCall(BPF_FUNC_map_update_elem)
  ]
115
116
# Bpf instructions that set the return value to 0 and exit. The tests in this
# file load this program wherever traffic or socket creation should be blocked.
INS_BPF_EXIT_BLOCK = [
    BpfMov64Imm(BPF_REG_0, 0),
    BpfExitInsn()
]

# Bpf instructions for a cgroup bpf filter to accept a packet and exit.
INS_CGROUP_ACCEPT = [
    # Set return value to 1 and exit.
    BpfMov64Imm(BPF_REG_0, 1),
    BpfExitInsn()
]

# Bpf instructions for a socket bpf filter to accept a packet and exit.
INS_SK_FILTER_ACCEPT = [
    # Precondition: BPF_REG_6 = sk_buff context
    # Load the packet length from BPF_REG_6 and store it in BPF_REG_0 as the
    # return value.
    BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0),
    BpfExitInsn()
]

# Update an existing map element with +1.
INS_PACK_COUNT_UPDATE = [
    # Precondition: BPF_REG_0 = Value retrieved from BPF maps
    # Add one to the corresponding eBPF value field for a specific eBPF key.
    # BPF_REG_0 is copied to BPF_REG_2, which is then the memory operand of
    # the XADD instruction; BPF_REG_1 holds the increment.
    BpfMov64Reg(BPF_REG_2, BPF_REG_0),
    BpfMov64Imm(BPF_REG_1, 1),
    BpfRawInsn(BPF_STX | BPF_XADD | BPF_W, BPF_REG_2, BPF_REG_1, 0, 0),
]

# Store the map key that the caller left in BPF_REG_0 on the stack at
# key_offset, where BpfFuncCountPacketInit expects to find it.
INS_BPF_PARAM_STORE = [
    BpfStxMem(BPF_DW, BPF_REG_10, BPF_REG_0, key_offset),
]
150
@unittest.skipUnless(HAVE_EBPF_ACCOUNTING,
                     "BPF helper function is not fully supported")
class BpfTest(net_test.NetworkTest):
  """Tests for eBPF map operations and socket filter programs.

  The whole class is skipped unless HAVE_EBPF_ACCOUNTING is true, so the
  individual tests do not repeat that condition.
  """

  def setUp(self):
    # -1 / None mean "nothing to release" in tearDown.
    self.map_fd = -1
    self.prog_fd = -1
    self.sock = None

  def tearDown(self):
    if self.prog_fd >= 0:
      os.close(self.prog_fd)
    if self.map_fd >= 0:
      os.close(self.map_fd)
    if self.sock:
      self.sock.close()

  def _LoadPacketCounter(self, load_key):
    """Loads a socket filter that counts packets per key into self.map_fd.

    The program saves the skb in BPF_REG_6, runs load_key, stores BPF_REG_0 as
    the map key and then creates or increments the matching map entry.

    Args:
      load_key: the instruction that leaves the map key in BPF_REG_0.
    """
    instructions = [
        # Move skb to BPF_REG_6 for further usage.
        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
        load_key
    ]
    # Concatenate the generic packet count bpf program to it.
    instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
                     + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE
                     + INS_SK_FILTER_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)

  def testCreateMap(self):
    key, value = 1, 1
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    UpdateMap(self.map_fd, key, value)
    self.assertEqual(value, LookupMap(self.map_fd, key).value)
    DeleteMap(self.map_fd, key)
    self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, key)

  def CheckAllMapEntry(self, nonexistent_key, totalEntries, value):
    """Walks the map with GetNextKey and checks every entry that is visited.

    Args:
      nonexistent_key: key to start iterating from.
      totalEntries: number of entries expected after the starting key; the
        next GetNextKey call after that many must fail with ENOENT.
      value: value every visited entry is expected to hold.
    """
    key = nonexistent_key
    for _ in xrange(totalEntries):
      key = GetNextKey(self.map_fd, key).value
      self.assertGreaterEqual(key, 0)
      self.assertEqual(value, LookupMap(self.map_fd, key).value)
    self.assertRaisesErrno(errno.ENOENT, GetNextKey, self.map_fd, key)

  def testIterateMap(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    value = 1024
    for key in xrange(0, TOTAL_ENTRIES):
      UpdateMap(self.map_fd, key, value)
    for key in xrange(0, TOTAL_ENTRIES):
      self.assertEqual(value, LookupMap(self.map_fd, key).value)
    self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, 101)
    nonexistent_key = -1
    self.CheckAllMapEntry(nonexistent_key, TOTAL_ENTRIES, value)

  def testFindFirstMapKey(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    value = 1024
    for key in xrange(0, TOTAL_ENTRIES):
      UpdateMap(self.map_fd, key, value)
    firstKey = GetFirstKey(self.map_fd)
    key = firstKey.value
    # Iteration starts after the first key, so one entry fewer remains.
    self.CheckAllMapEntry(key, TOTAL_ENTRIES - 1, value)

  def testRdOnlyMap(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES, map_flags=BPF_F_RDONLY)
    value = 1024
    key = 1
    self.assertRaisesErrno(errno.EPERM, UpdateMap, self.map_fd, key, value)
    self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, key)

  def testWrOnlyMap(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES, map_flags=BPF_F_WRONLY)
    value = 1024
    key = 1
    UpdateMap(self.map_fd, key, value)
    self.assertRaisesErrno(errno.EPERM, LookupMap, self.map_fd, key)

  def testProgLoad(self):
    # Move skb to BPF_REG_6 for further usage
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1)
    ]
    instructions += INS_SK_FILTER_ACCEPT
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
    # SocketUDPLoopBack hands back an open socket; close it right away.
    SocketUDPLoopBack(1, 4, self.prog_fd).close()
    SocketUDPLoopBack(1, 6, self.prog_fd).close()

  def testPacketBlock(self):
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, INS_BPF_EXIT_BLOCK)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 4, self.prog_fd)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 6, self.prog_fd)

  def testPacketCount(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    key = 0xf0f0
    # Set up the program with a constant key loaded at BPF_REG_0.
    self._LoadPacketCounter(BpfMov64Imm(BPF_REG_0, key))
    packet_count = 10
    SocketUDPLoopBack(packet_count, 4, self.prog_fd).close()
    SocketUDPLoopBack(packet_count, 6, self.prog_fd).close()
    # Both sockets count under the same key.
    self.assertEqual(packet_count * 2, LookupMap(self.map_fd, key).value)

  def testGetSocketCookie(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    # Call the helper function to get the socket cookie of the current skb and
    # return the cookie at REG0 for the next code block.
    self._LoadPacketCounter(BpfFuncCall(BPF_FUNC_get_socket_cookie))
    packet_count = 10
    def PacketCountByCookie(version):
      # The socket is kept in self.sock so that tearDown closes it if an
      # assertion fails before the explicit close below.
      self.sock = SocketUDPLoopBack(packet_count, version, self.prog_fd)
      cookie = sock_diag.SockDiag.GetSocketCookie(self.sock)
      self.assertEqual(packet_count, LookupMap(self.map_fd, cookie).value)
      self.sock.close()
    PacketCountByCookie(4)
    PacketCountByCookie(6)

  def testGetSocketUid(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    # Set up the program with the uid at BPF_REG_0.
    self._LoadPacketCounter(BpfFuncCall(BPF_FUNC_get_socket_uid))
    packet_count = 10
    uid = TEST_UID
    with net_test.RunAsUid(uid):
      self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 4, self.prog_fd).close()
      self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value)
      # Reset the counter so the IPv6 run is checked on its own.
      DeleteMap(self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 6, self.prog_fd).close()
      self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value)
314
@unittest.skipUnless(HAVE_EBPF_ACCOUNTING,
                     "Cgroup BPF is not fully supported")
class BpfCgroupTest(net_test.NetworkTest):
  """Tests for BPF programs attached to a cgroup.

  setUpClass mounts a cgroup2 hierarchy on a temporary directory and opens it;
  the tests attach programs to that directory fd.
  """

  @classmethod
  def setUpClass(cls):
    cls._cg_dir = tempfile.mkdtemp(prefix="cg_bpf-")
    # The command is passed as a list so the directory name is never split.
    try:
      subprocess.check_call(["mount", "-t", "cgroup2", "cg_bpf", cls._cg_dir])
    except (subprocess.CalledProcessError, OSError):
      # If an exception is thrown in setUpClass, the test fails and
      # tearDownClass is not called.
      os.rmdir(cls._cg_dir)
      raise
    try:
      cls._cg_fd = os.open(cls._cg_dir, os.O_DIRECTORY | os.O_RDONLY)
    except OSError:
      # Same reasoning as above: undo the mount before giving up.
      subprocess.call(["umount", cls._cg_dir])
      os.rmdir(cls._cg_dir)
      raise

  @classmethod
  def tearDownClass(cls):
    os.close(cls._cg_fd)
    subprocess.call(["umount", cls._cg_dir])
    os.rmdir(cls._cg_dir)

  def setUp(self):
    # -1 means "nothing to close" in tearDown.
    self.prog_fd = -1
    self.map_fd = -1

  def tearDown(self):
    if self.prog_fd >= 0:
      os.close(self.prog_fd)
    if self.map_fd >= 0:
      os.close(self.map_fd)
    # Best-effort cleanup: a test may have failed before detaching, and
    # detaching an attach type that has no program is allowed to fail.
    for attach_type in (BPF_CGROUP_INET_EGRESS, BPF_CGROUP_INET_INGRESS,
                        BPF_CGROUP_INET_SOCK_CREATE):
      try:
        BpfProgDetach(self._cg_fd, attach_type)
      except socket.error:
        pass

  def testCgroupBpfAttach(self):
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)

  def testCgroupIngress(self):
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 4, None)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 6, None)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)
    # With the program detached traffic flows again. SocketUDPLoopBack hands
    # back an open socket; close it right away.
    SocketUDPLoopBack(1, 4, None).close()
    SocketUDPLoopBack(1, 6, None).close()

  def testCgroupEgress(self):
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_EGRESS)
    self.assertRaisesErrno(errno.EPERM, SocketUDPLoopBack, 1, 4, None)
    self.assertRaisesErrno(errno.EPERM, SocketUDPLoopBack, 1, 6, None)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_EGRESS)
    SocketUDPLoopBack(1, 4, None).close()
    SocketUDPLoopBack(1, 6, None).close()

  def testCgroupBpfUid(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    # Similar to the program used in testGetSocketUid.
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
        BpfFuncCall(BPF_FUNC_get_socket_uid)
    ]
    instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
                     + INS_CGROUP_ACCEPT + INS_PACK_COUNT_UPDATE
                     + INS_CGROUP_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, instructions)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
    packet_count = 20
    uid = TEST_UID
    with net_test.RunAsUid(uid):
      self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 4, None).close()
      self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value)
      # Reset the counter so the IPv6 run is checked on its own.
      DeleteMap(self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 6, None).close()
      self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)

  def checkSocketCreate(self, family, socktype, success):
    """Creates one socket and checks that the outcome matches expectations.

    Args:
      family: address family passed to socket.socket().
      socktype: socket type passed to socket.socket().
      success: True if creation is expected to work, False if it is expected
        to fail with socket.error.
    """
    try:
      sock = socket.socket(family, socktype, 0)
    except socket.error as e:
      if success:
        self.fail("Failed to create socket family=%d type=%d err=%s" %
                  (family, socktype, os.strerror(e.errno)))
      return
    sock.close()
    if not success:
      self.fail("unexpected socket family=%d type=%d created, should be blocked" %
                (family, socktype))

  def trySocketCreate(self, success):
    """Runs checkSocketCreate for IPv4/IPv6 crossed with UDP/TCP sockets."""
    for family in [socket.AF_INET, socket.AF_INET6]:
      for socktype in [socket.SOCK_DGRAM, socket.SOCK_STREAM]:
        self.checkSocketCreate(family, socktype, success)

  @unittest.skipUnless(HAVE_EBPF_SOCKET,
                       "Cgroup BPF socket is not supported")
  def testCgroupSocketCreateBlock(self):
    instructions = [
        BpfFuncCall(BPF_FUNC_get_current_uid_gid),
        # Mask the helper's return value before comparing it with TEST_UID.
        # NOTE(review): the mask is 0xfffffff (28 bits), not 0xffffffff.
        # eBPF ALU immediates are 32-bit signed, so a full 32-bit mask may not
        # be expressible here -- confirm this is intentional before changing.
        BpfAlu64Imm(BPF_AND, BPF_REG_0, 0xfffffff),
        # Any other uid skips the two-instruction block program below and
        # lands on the accept program.
        BpfJumpImm(BPF_JNE, BPF_REG_0, TEST_UID, 2),
    ]
    instructions += INS_BPF_EXIT_BLOCK + INS_CGROUP_ACCEPT
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SOCK, instructions)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE)
    with net_test.RunAsUid(TEST_UID):
      # Socket creation with target uid should fail
      self.trySocketCreate(False)
    # Socket creation with a different uid should succeed
    self.trySocketCreate(True)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE)
    # Once the program is detached the target uid is no longer blocked.
    with net_test.RunAsUid(TEST_UID):
      self.trySocketCreate(True)
444
# Run the tests in this file when it is executed as a script.
if __name__ == "__main__":
  unittest.main()
447