• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python
2#
3# Copyright 2016 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
import ctypes
import ctypes.util
import errno
import os
import socket
import subprocess
import tempfile
import unittest

import bpf
from bpf import BPF_ADD
from bpf import BPF_AND
from bpf import BPF_CGROUP_INET_EGRESS
from bpf import BPF_CGROUP_INET_INGRESS
from bpf import BPF_CGROUP_INET_SOCK_CREATE
from bpf import BPF_DW
from bpf import BPF_F_RDONLY
from bpf import BPF_F_WRONLY
from bpf import BPF_FUNC_get_current_uid_gid
from bpf import BPF_FUNC_get_socket_cookie
from bpf import BPF_FUNC_get_socket_uid
from bpf import BPF_FUNC_ktime_get_boot_ns
from bpf import BPF_FUNC_ktime_get_ns
from bpf import BPF_FUNC_map_lookup_elem
from bpf import BPF_FUNC_map_update_elem
from bpf import BPF_FUNC_skb_change_head
from bpf import BPF_JNE
from bpf import BPF_MAP_TYPE_HASH
from bpf import BPF_PROG_TYPE_CGROUP_SKB
from bpf import BPF_PROG_TYPE_CGROUP_SOCK
from bpf import BPF_PROG_TYPE_SCHED_CLS
from bpf import BPF_PROG_TYPE_SOCKET_FILTER
from bpf import BPF_REG_0
from bpf import BPF_REG_1
from bpf import BPF_REG_10
from bpf import BPF_REG_2
from bpf import BPF_REG_3
from bpf import BPF_REG_4
from bpf import BPF_REG_6
from bpf import BPF_REG_7
from bpf import BPF_STX
from bpf import BPF_W
from bpf import BPF_XADD
from bpf import BpfAlu64Imm
from bpf import BpfExitInsn
from bpf import BpfFuncCall
from bpf import BpfJumpImm
from bpf import BpfLdxMem
from bpf import BpfLoadMapFd
from bpf import BpfMov64Imm
from bpf import BpfMov64Reg
from bpf import BpfProgAttach
from bpf import BpfProgAttachSocket
from bpf import BpfProgDetach
from bpf import BpfProgLoad
from bpf import BpfRawInsn
from bpf import BpfStMem
from bpf import BpfStxMem
from bpf import CreateMap
from bpf import DeleteMap
from bpf import GetFirstKey
from bpf import GetNextKey
from bpf import LookupMap
from bpf import UpdateMap
import csocket
import net_test
from net_test import LINUX_VERSION
import sock_diag
84
# libc handle with errno capture enabled. ctypes.util is a submodule that a bare
# "import ctypes" does not load, so it has to be imported explicitly.
libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)

# Feature gates, derived from kernel-version flags provided by the bpf module.
HAVE_EBPF_ACCOUNTING = bpf.HAVE_EBPF_4_9
HAVE_EBPF_SOCKET = bpf.HAVE_EBPF_4_14

# bpf_ktime_get_ns() was made non-GPL requiring in 5.8, and at the same time
# bpf_ktime_get_boot_ns() was added. Both of these changes were backported to
# Android Common Kernel in 4.14.221, 4.19.175 and 5.4.97.
# As such we require 4.14.222+, 4.19.176+, 5.4.98+ or 5.8.0+,
# but since we only really care about LTS releases:
HAVE_EBPF_KTIME_GET_NS_APACHE2 = (
    ((LINUX_VERSION > (4, 14, 221)) and (LINUX_VERSION < (4, 19, 0))) or
    ((LINUX_VERSION > (4, 19, 175)) and (LINUX_VERSION < (5, 4, 0))) or
    (LINUX_VERSION > (5, 4, 97))
)
HAVE_EBPF_KTIME_GET_BOOT_NS = HAVE_EBPF_KTIME_GET_NS_APACHE2

# Map key size in bytes; keys are written to the BPF stack as 64-bit (BPF_DW)
# words.
KEY_SIZE = 8
# Map value size in bytes; values are 32-bit (BPF_W) counters.
VALUE_SIZE = 4
TOTAL_ENTRIES = 20
TEST_UID = 54321
TEST_GID = 12345
# Offset, relative to the read-only frame pointer BPF_REG_10, of the stack slot
# that holds the map key.
key_offset = -8
# Offset, relative to the frame pointer BPF_REG_10, of the stack slot that
# holds the map value.
value_offset = -16
111
112
# Debug usage only.
def PrintMapInfo(map_fd):
  """Prints every <key, value> pair of a BPF map, one per line.

  The walk starts from a key the map is assumed not to contain and follows
  GetNextKey until it (or LookupMap) raises socket.error, at which point
  "no value" is printed and the walk stops.

  Args:
    map_fd: file descriptor of the BPF map to dump.
  """
  cursor = 10086  # Arbitrary key that the map is not expected to contain.
  while True:
    try:
      successor = GetNextKey(map_fd, cursor).value
      entry = LookupMap(map_fd, successor)
      print(repr(successor) + " : " + repr(entry.value))  # pylint: disable=superfluous-parens
    except socket.error:
      print("no value")  # pylint: disable=superfluous-parens
      break
    cursor = successor
126
127
# A dummy loopback function that causes a socket to send traffic to itself.
def SocketUDPLoopBack(packet_count, version, prog_fd):
  """Sends packet_count UDP datagrams from a socket to itself and reads them.

  Args:
    packet_count: number of datagrams to send and receive back.
    version: 4 or 6; selects AF_INET on 127.0.0.1 or AF_INET6 on ::1.
    prog_fd: fd of a BPF program to attach to the socket with
      BpfProgAttachSocket, or None to attach nothing.

  Returns:
    The bound, non-blocking socket. The caller owns it and should close it.

  Raises:
    KeyError: if version is not 4 or 6.
    AssertionError: if a datagram's payload or source address differs from
      what was sent.
    Any exception from the socket or BPF calls is re-raised after the socket
    is closed. For example, callers in this file expect an EAGAIN or EPERM
    error when a BPF program drops the traffic.
  """
  family = {4: socket.AF_INET, 6: socket.AF_INET6}[version]
  sock = socket.socket(family, socket.SOCK_DGRAM, 0)
  try:
    if prog_fd is not None:
      BpfProgAttachSocket(sock.fileno(), prog_fd)
    # Non-blocking, so a dropped datagram makes the receive fail instead of
    # hanging the test.
    net_test.SetNonBlocking(sock)
    sock.bind(({4: "127.0.0.1", 6: "::1"}[version], 0))
    addr = sock.getsockname()
    sockaddr = csocket.Sockaddr(addr)
    # Bytes, not str: Python 3 sockets only send bytes-like objects.
    payload = b"foo"
    for _ in range(packet_count):
      sock.sendto(payload, addr)
      data, retaddr = csocket.Recvfrom(sock, 4096, 0)
      if data != payload:
        raise AssertionError("unexpected payload %r" % (data,))
      # Compare with == only, so that just Sockaddr's equality operator is
      # relied upon.
      if not sockaddr == retaddr:  # pylint: disable=unneeded-not
        raise AssertionError("unexpected source address %r" % (retaddr,))
  except BaseException:
    # Don't leak the socket when a test deliberately triggers a failure.
    sock.close()
    raise
  return sock
145
146
# The main code block for the eBPF packet counting program. It expects the map
# key to have already been stored on the BPF stack at BPF_REG_10 + key_offset,
# and uses it to look up the BPF map. If the element does not exist in the map
# yet, the program inserts a new <key, 1> pair. Otherwise it jumps past this
# block and the block that follows it, so that the next code block handles it.
# REG0: register storing the return value of helper functions, and the final
# return value of the eBPF program.
# REG1 - REG5: scratch registers used for temporary values and for passing
# arguments to eBPF helper functions. Their values are not preserved across a
# helper call.
# REG6 - REG9: registers whose values are preserved across helper calls.
# REG10: read-only frame pointer. Values that must be passed by address are
# stored on the stack at a (negative) offset from it.
def BpfFuncCountPacketInit(map_fd):
  """Returns instructions that insert <key, 1> into map_fd if key is absent.

  Args:
    map_fd: fd of the BPF hash map holding the per-key counters.

  Returns:
    A list of instruction objects. If the key is already present, the
    conditional jump skips the remainder of this list plus the 2-instruction
    block that callers place right after it. Execution then continues with
    BPF_REG_0 holding the map_lookup_elem result for the existing element.
  """
  key_pos = BPF_REG_7
  return [
      # Compute the stack address of the key (BPF_REG_10 + key_offset) into
      # BPF_REG_7, which survives helper calls.
      BpfMov64Reg(key_pos, BPF_REG_10),
      BpfAlu64Imm(BPF_ADD, key_pos, key_offset),
      # Look up the key: map_lookup_elem(map, &key).
      BpfLoadMapFd(map_fd, BPF_REG_1),
      BpfMov64Reg(BPF_REG_2, key_pos),
      BpfFuncCall(BPF_FUNC_map_lookup_elem),
      # If the map element already exists (BPF_REG_0 != 0), jump out of this
      # code block. The offset of 10 covers the 8 instruction slots below
      # (BpfLoadMapFd presumably occupies two, being a 64-bit immediate load)
      # plus the 2-instruction accept block that callers append after this one.
      BpfJumpImm(BPF_JNE, BPF_REG_0, 0, 10),
      BpfLoadMapFd(map_fd, BPF_REG_1),
      BpfMov64Reg(BPF_REG_2, key_pos),
      # Initialize a new <key, value> pair with value 1 and write it to the map:
      # map_update_elem(map, &key, &value, 0).
      BpfStMem(BPF_W, BPF_REG_10, value_offset, 1),
      BpfMov64Reg(BPF_REG_3, BPF_REG_10),
      BpfAlu64Imm(BPF_ADD, BPF_REG_3, value_offset),
      BpfMov64Imm(BPF_REG_4, 0),
      BpfFuncCall(BPF_FUNC_map_update_elem)
  ]
183
184
# Bpf instructions that set the return value to 0 and exit. Tests in this file
# use this as the "block" verdict: dropping packets, or refusing socket
# creation for cgroup sock programs.
INS_BPF_EXIT_BLOCK = [BpfMov64Imm(BPF_REG_0, 0), BpfExitInsn()]

# Bpf instructions for a cgroup bpf filter to accept a packet and exit:
# set the return value to 1 and exit.
INS_CGROUP_ACCEPT = [BpfMov64Imm(BPF_REG_0, 1), BpfExitInsn()]

# Bpf instructions for a socket bpf filter to accept a packet and exit.
# Precondition: BPF_REG_6 = sk_buff context.
# The return value (BPF_REG_0) is the packet length, loaded as a 32-bit word
# from offset 0 of the context.
INS_SK_FILTER_ACCEPT = [
    BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0),
    BpfExitInsn(),
]

# Update an existing map element with +1.
# Precondition: BPF_REG_0 = value retrieved from the BPF map (map_lookup_elem
# result).
INS_PACK_COUNT_UPDATE = [
    BpfMov64Reg(BPF_REG_2, BPF_REG_0),
    BpfMov64Imm(BPF_REG_1, 1),
    # BPF_XADD: atomically add BPF_REG_1 (= 1) to the 32-bit value that
    # BPF_REG_2 points to.
    BpfRawInsn(BPF_STX | BPF_XADD | BPF_W, BPF_REG_2, BPF_REG_1, 0, 0),
]

# Store BPF_REG_0 (the map key) as a 64-bit word on the stack at
# BPF_REG_10 + key_offset.
INS_BPF_PARAM_STORE = [BpfStxMem(BPF_DW, BPF_REG_10, BPF_REG_0, key_offset)]
218
219
@unittest.skipUnless(HAVE_EBPF_ACCOUNTING,
                     "BPF helper function is not fully supported")
class BpfTest(net_test.NetworkTest):
  """Tests for BPF maps and for socket-filter and TC (SCHED_CLS) programs."""

  def setUp(self):
    super(BpfTest, self).setUp()
    # -1 / None mean "nothing to release" to tearDown.
    self.map_fd = -1
    self.prog_fd = -1
    self.sock = None

  def tearDown(self):
    if self.prog_fd >= 0:
      os.close(self.prog_fd)
    if self.map_fd >= 0:
      os.close(self.map_fd)
    if self.sock:
      self.sock.close()
    super(BpfTest, self).tearDown()

  def testCreateMap(self):
    key, value = 1, 1
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    UpdateMap(self.map_fd, key, value)
    self.assertEqual(value, LookupMap(self.map_fd, key).value)
    DeleteMap(self.map_fd, key)
    self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, key)

  def CheckAllMapEntry(self, nonexistent_key, total_entries, value):
    """Walks self.map_fd starting after nonexistent_key.

    Checks that GetNextKey yields exactly total_entries more keys, each
    non-negative and mapped to value, and then fails with ENOENT.

    Args:
      nonexistent_key: key to start the walk from (need not be in the map).
      total_entries: number of keys expected to follow it.
      value: value every visited key must map to.
    """
    key = nonexistent_key
    for _ in range(total_entries):
      key = GetNextKey(self.map_fd, key).value
      self.assertGreaterEqual(key, 0)
      self.assertEqual(value, LookupMap(self.map_fd, key).value)
    self.assertRaisesErrno(errno.ENOENT, GetNextKey, self.map_fd, key)

  def testIterateMap(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    value = 1024
    for key in range(TOTAL_ENTRIES):
      UpdateMap(self.map_fd, key, value)
    for key in range(TOTAL_ENTRIES):
      self.assertEqual(value, LookupMap(self.map_fd, key).value)
    self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, 101)
    nonexistent_key = -1
    self.CheckAllMapEntry(nonexistent_key, TOTAL_ENTRIES, value)

  def testFindFirstMapKey(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    value = 1024
    for key in range(TOTAL_ENTRIES):
      UpdateMap(self.map_fd, key, value)
    first_key = GetFirstKey(self.map_fd)
    key = first_key.value
    # The first key itself is not counted, so TOTAL_ENTRIES - 1 keys follow it.
    self.CheckAllMapEntry(key, TOTAL_ENTRIES - 1, value)

  def testRdOnlyMap(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES, map_flags=BPF_F_RDONLY)
    value = 1024
    key = 1
    self.assertRaisesErrno(errno.EPERM, UpdateMap, self.map_fd, key, value)
    self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, key)

  def testWrOnlyMap(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES, map_flags=BPF_F_WRONLY)
    value = 1024
    key = 1
    UpdateMap(self.map_fd, key, value)
    self.assertRaisesErrno(errno.EPERM, LookupMap, self.map_fd, key)

  def testProgLoad(self):
    # Move skb to BPF_REG_6 for further usage
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1)
    ]
    instructions += INS_SK_FILTER_ACCEPT
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
    # SocketUDPLoopBack hands ownership of its socket to us; close it.
    SocketUDPLoopBack(1, 4, self.prog_fd).close()
    SocketUDPLoopBack(1, 6, self.prog_fd).close()

  def testPacketBlock(self):
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, INS_BPF_EXIT_BLOCK)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 4, self.prog_fd)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 6, self.prog_fd)

  def testPacketCount(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    key = 0xf0f0
    # Set up instruction block with key loaded at BPF_REG_0.
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
        BpfMov64Imm(BPF_REG_0, key)
    ]
    # Concatenate the generic packet count bpf program to it.
    instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
                     + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE
                     + INS_SK_FILTER_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
    packet_count = 10
    SocketUDPLoopBack(packet_count, 4, self.prog_fd).close()
    SocketUDPLoopBack(packet_count, 6, self.prog_fd).close()
    self.assertEqual(packet_count * 2, LookupMap(self.map_fd, key).value)

  ##############################################################################
  #
  # Test for presence of kernel patch:
  #
  #   ANDROID: net: bpf: Allow TC programs to call BPF_FUNC_skb_change_head
  #
  # 4.14: https://android-review.googlesource.com/c/kernel/common/+/1237789
  #       commit fe82848d9c1c887d2a84d3738c13e644d01b6d6f
  #
  # 4.19: https://android-review.googlesource.com/c/kernel/common/+/1237788
  #       commit 6e04d94ab72435b45c413daff63520fd724e260e
  #
  # 5.4:  https://android-review.googlesource.com/c/kernel/common/+/1237787
  #       commit d730995e7bc5b4c10cc176235b704a274e6ec16f
  #
  # Upstream in Linux v5.8:
  #   net: bpf: Allow TC programs to call BPF_FUNC_skb_change_head
  #   commit 6f3f65d80dac8f2bafce2213005821fccdce194c
  #
  @unittest.skipUnless(bpf.HAVE_EBPF_4_14,
                       "no bpf_skb_change_head() support for pre-4.14 kernels")
  def testSkbChangeHead(self):
    # long bpf_skb_change_head(struct sk_buff *skb, u32 len, u64 flags)
    instructions = [
        BpfMov64Imm(BPF_REG_2, 14),  # u32 len
        BpfMov64Imm(BPF_REG_3, 0),   # u64 flags
        BpfFuncCall(BPF_FUNC_skb_change_head),
    ] + INS_BPF_EXIT_BLOCK
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SCHED_CLS, instructions,
                               b"Apache 2.0")
    # No exceptions? Good.

  def testKtimeGetNsGPL(self):
    instructions = [BpfFuncCall(BPF_FUNC_ktime_get_ns)] + INS_BPF_EXIT_BLOCK
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SCHED_CLS, instructions)
    # No exceptions? Good.

  ##############################################################################
  #
  # Test for presence of kernel patch:
  #
  #   UPSTREAM: net: bpf: Make bpf_ktime_get_ns() available to non GPL programs
  #
  # 4.14: https://android-review.googlesource.com/c/kernel/common/+/1585269
  #       commit cbb4c73f9eab8f3c8ac29175d45c99ccba382e15
  #
  # 4.19: https://android-review.googlesource.com/c/kernel/common/+/1355243
  #       commit 272e21ccc9a92feeee80aff0587410a314b73c5b
  #
  # 5.4:  https://android-review.googlesource.com/c/kernel/common/+/1355422
  #       commit 45217b91eaaa3a563247c4f470f4cb785de6b1c6
  #
  @unittest.skipUnless(HAVE_EBPF_KTIME_GET_NS_APACHE2,
                       "no bpf_ktime_get_ns() support for non-GPL programs")
  def testKtimeGetNsApache2(self):
    instructions = [BpfFuncCall(BPF_FUNC_ktime_get_ns)] + INS_BPF_EXIT_BLOCK
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SCHED_CLS, instructions,
                               b"Apache 2.0")
    # No exceptions? Good.

  ##############################################################################
  #
  # Test for presence of kernel patch:
  #
  #   BACKPORT: bpf: add bpf_ktime_get_boot_ns()
  #
  # 4.14: https://android-review.googlesource.com/c/kernel/common/+/1585587
  #       commit 34073d7a8ee47ca908b56e9a1d14ca0615fdfc09
  #
  # 4.19: https://android-review.googlesource.com/c/kernel/common/+/1585606
  #       commit 4812ec50935dfe59ba9f48a572e278dd0b02af68
  #
  # 5.4:  https://android-review.googlesource.com/c/kernel/common/+/1585252
  #       commit 57b3f4830fb66a6038c4c1c66ca2e138fe8be231
  #
  @unittest.skipUnless(HAVE_EBPF_KTIME_GET_BOOT_NS,
                       "no bpf_ktime_get_boot_ns() support")
  def testKtimeGetBootNs(self):
    instructions = [
        BpfFuncCall(BPF_FUNC_ktime_get_boot_ns),
    ] + INS_BPF_EXIT_BLOCK
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SCHED_CLS, instructions,
                               b"Apache 2.0")
    # No exceptions? Good.

  def testGetSocketCookie(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    # Move skb to REG6 for further usage, call helper function to get socket
    # cookie of current skb and return the cookie at REG0 for next code block
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
        BpfFuncCall(BPF_FUNC_get_socket_cookie)
    ]
    instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
                     + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE
                     + INS_SK_FILTER_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
    packet_count = 10
    def PacketCountByCookie(version):
      # Kept in self.sock so tearDown closes it if an assertion below fails.
      self.sock = SocketUDPLoopBack(packet_count, version, self.prog_fd)
      cookie = sock_diag.SockDiag.GetSocketCookie(self.sock)
      self.assertEqual(packet_count, LookupMap(self.map_fd, cookie).value)
      self.sock.close()
    PacketCountByCookie(4)
    PacketCountByCookie(6)

  def testGetSocketUid(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    # Set up the instruction with uid at BPF_REG_0.
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
        BpfFuncCall(BPF_FUNC_get_socket_uid)
    ]
    # Concatenate the generic packet count bpf program to it.
    instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
                     + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE
                     + INS_SK_FILTER_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
    packet_count = 10
    uid = TEST_UID
    with net_test.RunAsUid(uid):
      self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 4, self.prog_fd).close()
      self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value)
      DeleteMap(self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 6, self.prog_fd).close()
      self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value)
463
464
@unittest.skipUnless(HAVE_EBPF_ACCOUNTING,
                     "Cgroup BPF is not fully supported")
class BpfCgroupTest(net_test.NetworkTest):
  """Tests for BPF programs attached to a temporary cgroup v2 mount."""

  @classmethod
  def setUpClass(cls):
    super(BpfCgroupTest, cls).setUpClass()
    cls._cg_dir = tempfile.mkdtemp(prefix="cg_bpf-")
    cmd = "mount -t cgroup2 cg_bpf %s" % cls._cg_dir
    # If an exception is thrown in setUpClass, the test fails and
    # tearDownClass is not called, so undo any partial setup before re-raising.
    try:
      subprocess.check_call(cmd.split())
    except (subprocess.CalledProcessError, OSError):
      # OSError covers the mount binary being missing or not executable.
      os.rmdir(cls._cg_dir)
      raise
    try:
      cls._cg_fd = os.open(cls._cg_dir, os.O_DIRECTORY | os.O_RDONLY)
    except OSError:
      subprocess.call(("umount %s" % cls._cg_dir).split())
      os.rmdir(cls._cg_dir)
      raise

  @classmethod
  def tearDownClass(cls):
    os.close(cls._cg_fd)
    subprocess.call(("umount %s" % cls._cg_dir).split())
    os.rmdir(cls._cg_dir)
    super(BpfCgroupTest, cls).tearDownClass()

  def setUp(self):
    super(BpfCgroupTest, self).setUp()
    # -1 means "no fd to close" to tearDown.
    self.prog_fd = -1
    self.map_fd = -1

  def tearDown(self):
    if self.prog_fd >= 0:
      os.close(self.prog_fd)
    if self.map_fd >= 0:
      os.close(self.map_fd)
    # Best-effort detach of anything a test may have left attached; a detach
    # for an attach type with nothing attached presumably fails, which is
    # deliberately ignored.
    for attach_type in (BPF_CGROUP_INET_EGRESS, BPF_CGROUP_INET_INGRESS,
                        BPF_CGROUP_INET_SOCK_CREATE):
      try:
        BpfProgDetach(self._cg_fd, attach_type)
      except socket.error:
        pass
    super(BpfCgroupTest, self).tearDown()

  def testCgroupBpfAttach(self):
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)

  def testCgroupIngress(self):
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 4, None)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 6, None)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)
    # SocketUDPLoopBack hands ownership of its socket to us; close it.
    SocketUDPLoopBack(1, 4, None).close()
    SocketUDPLoopBack(1, 6, None).close()

  def testCgroupEgress(self):
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_EGRESS)
    self.assertRaisesErrno(errno.EPERM, SocketUDPLoopBack, 1, 4, None)
    self.assertRaisesErrno(errno.EPERM, SocketUDPLoopBack, 1, 6, None)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_EGRESS)
    SocketUDPLoopBack(1, 4, None).close()
    SocketUDPLoopBack(1, 6, None).close()

  def testCgroupBpfUid(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    # Similar to the program used in testGetSocketUid.
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
        BpfFuncCall(BPF_FUNC_get_socket_uid)
    ]
    instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
                     + INS_CGROUP_ACCEPT + INS_PACK_COUNT_UPDATE
                     + INS_CGROUP_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, instructions)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
    packet_count = 20
    uid = TEST_UID
    with net_test.RunAsUid(uid):
      self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 4, None).close()
      self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value)
      DeleteMap(self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 6, None).close()
      self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)

  def checkSocketCreate(self, family, socktype, success):
    """Creates and closes one socket, failing the test if the outcome is wrong.

    Args:
      family: socket address family.
      socktype: socket type.
      success: True if creation is expected to succeed, False if blocked.
    """
    try:
      sock = socket.socket(family, socktype, 0)
      sock.close()
    except socket.error as e:
      if success:
        self.fail("Failed to create socket family=%d type=%d err=%s" %
                  (family, socktype, os.strerror(e.errno)))
      return
    if not success:
      self.fail("unexpected socket family=%d type=%d created, should be blocked"
                % (family, socktype))

  def trySocketCreate(self, success):
    """Runs checkSocketCreate for {IPv4, IPv6} x {UDP, TCP} sockets."""
    for family in [socket.AF_INET, socket.AF_INET6]:
      for socktype in [socket.SOCK_DGRAM, socket.SOCK_STREAM]:
        self.checkSocketCreate(family, socktype, success)

  @unittest.skipUnless(HAVE_EBPF_SOCKET,
                       "Cgroup BPF socket is not supported")
  def testCgroupSocketCreateBlock(self):
    instructions = [
        # BPF_REG_0 = current_gid << 32 | current_uid.
        BpfFuncCall(BPF_FUNC_get_current_uid_gid),
        # NOTE(review): 0xfffffff keeps only the low 28 bits, not the full
        # 32-bit uid; that is enough for TEST_UID. A 64-bit AND with 0xffffffff
        # would not help, because ALU64 immediates are sign-extended to 64 bits.
        BpfAlu64Imm(BPF_AND, BPF_REG_0, 0xfffffff),
        # If the uid is not TEST_UID, skip the 2-instruction block verdict.
        BpfJumpImm(BPF_JNE, BPF_REG_0, TEST_UID, 2),
    ]
    # Return 0 (block) for TEST_UID and 1 (accept) for every other uid.
    instructions += INS_BPF_EXIT_BLOCK + INS_CGROUP_ACCEPT
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SOCK, instructions)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE)
    with net_test.RunAsUid(TEST_UID):
      # Socket creation with target uid should fail
      self.trySocketCreate(False)
    # Socket creation with a different uid should succeed
    self.trySocketCreate(True)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE)
    with net_test.RunAsUid(TEST_UID):
      self.trySocketCreate(True)
598
if __name__ == "__main__":
  # Run every test case defined in this module when executed as a script.
  unittest.main()
601