• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3
2#
3# Copyright 2016 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import ctypes
18import errno
19import os
20import socket
21import tempfile
22import unittest
23
24import bpf
25from bpf import BPF_ADD
26from bpf import BPF_AND
27from bpf import BPF_CGROUP_INET_EGRESS
28from bpf import BPF_CGROUP_INET_INGRESS
29from bpf import BPF_CGROUP_INET_SOCK_CREATE
30from bpf import BPF_DW
31from bpf import BPF_F_RDONLY
32from bpf import BPF_F_WRONLY
33from bpf import BPF_FUNC_get_current_uid_gid
34from bpf import BPF_FUNC_get_socket_cookie
35from bpf import BPF_FUNC_get_socket_uid
36from bpf import BPF_FUNC_ktime_get_boot_ns
37from bpf import BPF_FUNC_ktime_get_ns
38from bpf import BPF_FUNC_map_lookup_elem
39from bpf import BPF_FUNC_map_update_elem
40from bpf import BPF_FUNC_skb_change_head
41from bpf import BPF_JNE
42from bpf import BPF_MAP_TYPE_ARRAY
43from bpf import BPF_MAP_TYPE_HASH
44from bpf import BPF_PROG_TYPE_CGROUP_SKB
45from bpf import BPF_PROG_TYPE_CGROUP_SOCK
46from bpf import BPF_PROG_TYPE_SCHED_CLS
47from bpf import BPF_PROG_TYPE_SOCKET_FILTER
48from bpf import BPF_REG_0
49from bpf import BPF_REG_1
50from bpf import BPF_REG_10
51from bpf import BPF_REG_2
52from bpf import BPF_REG_3
53from bpf import BPF_REG_4
54from bpf import BPF_REG_6
55from bpf import BPF_REG_7
56from bpf import BPF_STX
57from bpf import BPF_W
58from bpf import BPF_XADD
59from bpf import BpfAlu64Imm
60from bpf import BpfExitInsn
61from bpf import BpfFuncCall
62from bpf import BpfJumpImm
63from bpf import BpfLdxMem
64from bpf import BpfLoadMapFd
65from bpf import BpfMov64Imm
66from bpf import BpfMov64Reg
67from bpf import BpfProgAttach
68from bpf import BpfProgAttachSocket
69from bpf import BpfProgDetach
70from bpf import BpfProgLoad
71from bpf import BpfRawInsn
72from bpf import BpfStMem
73from bpf import BpfStxMem
74from bpf import CreateMap
75from bpf import DeleteMap
76from bpf import GetFirstKey
77from bpf import GetNextKey
78from bpf import LookupMap
79from bpf import UpdateMap
80import csocket
81import net_test
82import sock_diag
83
# "import ctypes" alone does not load the ctypes.util submodule; import it
# explicitly instead of relying on some other module having done so already.
import ctypes.util

# Handle to the C library, opened with use_errno=True so that errno set by
# calls made through it can be read back with ctypes.get_errno().
libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)

# Key/value sizes (in bytes) and capacity used for every map created below.
KEY_SIZE = 4
VALUE_SIZE = 4
TOTAL_ENTRIES = 20
TEST_UID = 54321
TEST_GID = 12345
# Offset, relative to the address held in BPF_REG_10, of the stack slot that
# holds the map key.
key_offset = -8
# Offset, relative to the address held in BPF_REG_10, of the stack slot that
# holds the map value.
value_offset = -16
95
96
97# Debug usage only.
def PrintMapInfo(map_fd):
  """Prints every "key : value" pair reachable by walking the map.

  The walk repeatedly asks GetNextKey for the key following the current one
  and looks that key up. It stops, after printing "no value", as soon as
  either call raises socket.error.

  Args:
    map_fd: file descriptor of the BPF map to dump.
  """
  # Start from a key the map is not expected to contain.
  cursor = 10086
  while True:
    try:
      found_key = GetNextKey(map_fd, cursor).value
      entry = LookupMap(map_fd, found_key)
      line = "{!r} : {!r}".format(found_key, entry.value)
      print(line)  # pylint: disable=superfluous-parens
      cursor = found_key
    except socket.error:
      print("no value")  # pylint: disable=superfluous-parens
      break
110
111
112# A dummy loopback function that causes a socket to send traffic to itself.
def SocketUDPLoopBack(packet_count, version, prog_fd):
  """Makes a UDP socket send packets to itself over loopback and checks them.

  Args:
    packet_count: number of b"foo" datagrams to send and read back.
    version: IP version of the socket, 4 or 6.
    prog_fd: fd of a BPF program to attach to the socket with
      BpfProgAttachSocket, or None to attach nothing.

  Returns:
    The socket, still open; the caller is responsible for closing it.

  Raises:
    KeyError: if version is neither 4 nor 6.
    AssertionError: if a received payload or source address does not match
      what was sent.
    Any exception raised by the socket calls is propagated after the socket
    has been closed.
  """
  family, loopback = {
      4: (socket.AF_INET, "127.0.0.1"),
      6: (socket.AF_INET6, "::1"),
  }[version]
  sock = socket.socket(family, socket.SOCK_DGRAM, 0)
  try:
    if prog_fd is not None:
      BpfProgAttachSocket(sock.fileno(), prog_fd)
    net_test.SetNonBlocking(sock)
    # Bind to an ephemeral port, then use the bound address as destination.
    sock.bind((loopback, 0))
    addr = sock.getsockname()
    sockaddr = csocket.Sockaddr(addr)
    for _ in range(packet_count):
      sock.sendto(b"foo", addr)
      data, retaddr = csocket.Recvfrom(sock, 4096, 0)
      assert b"foo" == data
      assert sockaddr == retaddr
    return sock
  except BaseException:
    # Do not leak the socket on any failure (including KeyboardInterrupt);
    # a bare raise keeps the original traceback intact.
    sock.close()
    raise
133
134
135# The main code block of the eBPF packet counting program. It takes a preloaded
136# key from BPF_REG_0 and uses it to look up the bpf map. If the element does
137# not exist in the map yet, the program updates the map with a new <key, 1>
138# pair. Otherwise it jumps to the next code block, which handles that case.
139# REG0: register storing the return value of a helper function and the final
140# return value of the eBPF program.
141# REG1 - REG5: temporary register used for storing values and load parameters
142# into eBPF helper function. After calling helper function, the value for these
143# registers will be reset.
144# REG6 - REG9: registers store values that will not be cleared when calling
145# eBPF helper function.
146# REG10: A stack stores values need to be accessed by the address. Program can
147# retrieve the address of a value by specifying the position of the value in
148# the stack.
def BpfFuncCountPacketInit(map_fd):
  """Builds the "insert <key, 1> if the key is missing" part of a counter.

  The generated instructions expect the map key to already be on the stack at
  key_offset from BPF_REG_10 (INS_BPF_PARAM_STORE puts it there). They look
  the key up in the map; when the lookup returns a non-zero result the program
  jumps forward past this block, otherwise a new entry with value 1 is written.

  Args:
    map_fd: file descriptor of the BPF map holding the counters.

  Returns:
    A list of BPF instructions. It contains no exit instruction, so the
    caller must append the code that follows it.
  """
  key_ptr = BPF_REG_7
  return [
      # Compute the address of the key's stack slot and keep it in BPF_REG_7,
      # which holds its value across helper calls.
      BpfMov64Reg(key_ptr, BPF_REG_10),
      BpfAlu64Imm(BPF_ADD, key_ptr, key_offset),
      # Load map fd and look up the key in the map.
      BpfLoadMapFd(map_fd, BPF_REG_1),
      BpfMov64Reg(BPF_REG_2, key_ptr),
      BpfFuncCall(BPF_FUNC_map_lookup_elem),
      # If the map element already exists (lookup result != 0), jump out of
      # this code block and let the next part handle it. BPF_JNE is the jump
      # opcode intended here; it was previously spelled BPF_AND, which only
      # worked because the two constants share an encoding.
      # NOTE(review): the offset of 10 reaches past the end of this list; it
      # presumably also skips the two-instruction accept block that every
      # caller in this file appends next - confirm before changing either.
      BpfJumpImm(BPF_JNE, BPF_REG_0, 0, 10),
      BpfLoadMapFd(map_fd, BPF_REG_1),
      BpfMov64Reg(BPF_REG_2, key_ptr),
      # Initialize a new <key, value> pair with value 1 and write it to the
      # map: the value is staged on the stack and passed by address.
      BpfStMem(BPF_W, BPF_REG_10, value_offset, 1),
      BpfMov64Reg(BPF_REG_3, BPF_REG_10),
      BpfAlu64Imm(BPF_ADD, BPF_REG_3, value_offset),
      BpfMov64Imm(BPF_REG_4, 0),
      BpfFuncCall(BPF_FUNC_map_update_elem)
  ]
171
172
# Bpf instructions that set the return value to 0 and exit. The tests in this
# file use this as the "block" program.
INS_BPF_EXIT_BLOCK = [
    BpfMov64Imm(BPF_REG_0, 0),
    BpfExitInsn()
]

# Bpf instructions for a cgroup bpf filter to accept a packet and exit.
INS_CGROUP_ACCEPT = [
    # Set return value to 1 and exit.
    BpfMov64Imm(BPF_REG_0, 1),
    BpfExitInsn()
]

# Bpf instructions for a socket bpf filter to accept a packet and exit.
INS_SK_FILTER_ACCEPT = [
    # Precondition: BPF_REG_6 = sk_buff context
    # Load the 32-bit word at offset 0 of the context (the packet length) from
    # BPF_REG_6 and store it in BPF_REG_0 as the return value.
    BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0),
    BpfExitInsn()
]

# Update an existing map element with +1. This list has no exit instruction,
# so execution continues into whatever block is appended after it.
INS_PACK_COUNT_UPDATE = [
    # Precondition: BPF_REG_0 = result of the map lookup for the key
    # Add one to the corresponding eBPF value field for a specific eBPF key:
    # BPF_REG_2 takes the lookup result, BPF_REG_1 the increment, and the XADD
    # instruction adds BPF_REG_1 to the 32-bit word that BPF_REG_2 refers to.
    BpfMov64Reg(BPF_REG_2, BPF_REG_0),
    BpfMov64Imm(BPF_REG_1, 1),
    BpfRawInsn(BPF_STX | BPF_XADD | BPF_W, BPF_REG_2, BPF_REG_1, 0, 0),
]

# Store BPF_REG_0 (the map key) as a 64-bit value in the stack slot at
# key_offset from BPF_REG_10, where the counting code expects to find it.
INS_BPF_PARAM_STORE = [
    BpfStxMem(BPF_DW, BPF_REG_10, BPF_REG_0, key_offset),
]
206
207
class BpfTest(net_test.NetworkTest):
  """Tests for BPF maps, BPF program loading and BPF helper availability.

  Tests record the fds they create in self.map_fd and self.prog_fd, and any
  socket that has to outlive a helper call in self.sock, so that tearDown can
  release them even when an assertion fails part-way through a test.
  """

  def setUp(self):
    super().setUp()
    # -1 / None mean "nothing to release" in tearDown.
    self.map_fd = -1
    self.prog_fd = -1
    self.sock = None

  def tearDown(self):
    if self.prog_fd >= 0:
      os.close(self.prog_fd)
      self.prog_fd = -1
    if self.map_fd >= 0:
      os.close(self.map_fd)
      self.map_fd = -1
    if self.sock:
      self.sock.close()
      self.sock = None
    super().tearDown()

  def testCreateMap(self):
    """An entry can be inserted, read back and deleted from a hash map."""
    key, value = 1, 1
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    UpdateMap(self.map_fd, key, value)
    self.assertEqual(value, LookupMap(self.map_fd, key).value)
    DeleteMap(self.map_fd, key)
    self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, key)

  def CheckAllMapEntry(self, nonexistent_key, total_entries, value):
    """Walks self.map_fd with GetNextKey and checks every visited entry.

    Args:
      nonexistent_key: key to start the walk from. Callers in this file pass
        either a key that is not in the map or the map's first key.
      total_entries: number of entries expected after the starting key.
      value: value every visited entry must hold.

    After total_entries steps, one more GetNextKey call must fail with ENOENT.
    """
    key = nonexistent_key
    for _ in range(total_entries):
      key = GetNextKey(self.map_fd, key).value
      self.assertGreaterEqual(key, 0)
      self.assertEqual(value, LookupMap(self.map_fd, key).value)
    # The walk must be exhausted exactly here.
    self.assertRaisesErrno(errno.ENOENT, GetNextKey, self.map_fd, key)

  def testIterateMap(self):
    """A full hash map can be iterated starting from a key it lacks."""
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    value = 1024
    for key in range(0, TOTAL_ENTRIES):
      UpdateMap(self.map_fd, key, value)
    for key in range(0, TOTAL_ENTRIES):
      self.assertEqual(value, LookupMap(self.map_fd, key).value)
    self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, 101)
    nonexistent_key = -1
    self.CheckAllMapEntry(nonexistent_key, TOTAL_ENTRIES, value)

  def testFindFirstMapKey(self):
    """GetFirstKey returns a key from which the other entries follow."""
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    value = 1024
    for key in range(0, TOTAL_ENTRIES):
      UpdateMap(self.map_fd, key, value)
    first_key = GetFirstKey(self.map_fd)
    key = first_key.value
    # The first key itself is already consumed, so one fewer entry remains.
    self.CheckAllMapEntry(key, TOTAL_ENTRIES - 1, value)

  def testArrayNonZeroOffset(self):
    """An array map element at an index other than 0 can be set and read."""
    self.map_fd = CreateMap(BPF_MAP_TYPE_ARRAY, KEY_SIZE, VALUE_SIZE, 2)
    key = 1
    value = 123
    UpdateMap(self.map_fd, key, value)
    self.assertEqual(value, LookupMap(self.map_fd, key).value)

  def testRdOnlyMap(self):
    """A BPF_F_RDONLY map rejects updates (EPERM) but allows lookups."""
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES, map_flags=BPF_F_RDONLY)
    value = 1024
    key = 1
    self.assertRaisesErrno(errno.EPERM, UpdateMap, self.map_fd, key, value)
    # The lookup is permitted; it fails only because the key was never added.
    self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, key)

  def testWrOnlyMap(self):
    """A BPF_F_WRONLY map allows updates but rejects lookups (EPERM)."""
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES, map_flags=BPF_F_WRONLY)
    value = 1024
    key = 1
    UpdateMap(self.map_fd, key, value)
    self.assertRaisesErrno(errno.EPERM, LookupMap, self.map_fd, key)

  def testProgLoad(self):
    """An accept-everything socket filter loads and lets traffic through."""
    # Move skb to BPF_REG_6 for further usage
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1)
    ]
    instructions += INS_SK_FILTER_ACCEPT
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
    SocketUDPLoopBack(1, 4, self.prog_fd).close()
    SocketUDPLoopBack(1, 6, self.prog_fd).close()

  def testPacketBlock(self):
    """A socket filter returning 0 makes the loopback receive hit EAGAIN."""
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, INS_BPF_EXIT_BLOCK)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 4, self.prog_fd)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 6, self.prog_fd)

  def testPacketCount(self):
    """A socket filter counts packets in a map under a fixed key."""
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    key = 0xf0f0
    # Set up instruction block with key loaded at BPF_REG_0.
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
        BpfMov64Imm(BPF_REG_0, key)
    ]
    # Concatenate the generic packet count bpf program to it.
    instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
                     + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE
                     + INS_SK_FILTER_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
    packet_count = 10
    SocketUDPLoopBack(packet_count, 4, self.prog_fd).close()
    SocketUDPLoopBack(packet_count, 6, self.prog_fd).close()
    # Both sockets count under the same key, so the totals add up.
    self.assertEqual(packet_count * 2, LookupMap(self.map_fd, key).value)

  ##############################################################################
  #
  # Test for presence of kernel patch:
  #
  #   ANDROID: net: bpf: Allow TC programs to call BPF_FUNC_skb_change_head
  #
  # 4.14: https://android-review.googlesource.com/c/kernel/common/+/1237789
  #       commit fe82848d9c1c887d2a84d3738c13e644d01b6d6f
  #
  # 4.19: https://android-review.googlesource.com/c/kernel/common/+/1237788
  #       commit 6e04d94ab72435b45c413daff63520fd724e260e
  #
  # 5.4:  https://android-review.googlesource.com/c/kernel/common/+/1237787
  #       commit d730995e7bc5b4c10cc176235b704a274e6ec16f
  #
  # Upstream in Linux v5.8:
  #   net: bpf: Allow TC programs to call BPF_FUNC_skb_change_head
  #   commit 6f3f65d80dac8f2bafce2213005821fccdce194c
  #
  def testSkbChangeHead(self):
    """A SCHED_CLS program calling skb_change_head must load."""
    # long bpf_skb_change_head(struct sk_buff *skb, u32 len, u64 flags)
    instructions = [
        BpfMov64Imm(BPF_REG_2, 14),  # u32 len
        BpfMov64Imm(BPF_REG_3, 0),   # u64 flags
        BpfFuncCall(BPF_FUNC_skb_change_head),
    ] + INS_BPF_EXIT_BLOCK
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SCHED_CLS, instructions,
                               b"Apache 2.0")
    # No exceptions? Good.

  def testKtimeGetNsGPL(self):
    """ktime_get_ns is callable when no explicit license is passed."""
    # NOTE(review): presumably BpfProgLoad defaults to a GPL license string,
    # as the test name suggests - confirm in bpf.py.
    instructions = [BpfFuncCall(BPF_FUNC_ktime_get_ns)] + INS_BPF_EXIT_BLOCK
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SCHED_CLS, instructions)
    # No exceptions? Good.

  ##############################################################################
  #
  # Test for presence of kernel patch:
  #
  #   UPSTREAM: net: bpf: Make bpf_ktime_get_ns() available to non GPL programs
  #
  # 4.14: https://android-review.googlesource.com/c/kernel/common/+/1585269
  #       commit cbb4c73f9eab8f3c8ac29175d45c99ccba382e15
  #
  # 4.19: https://android-review.googlesource.com/c/kernel/common/+/1355243
  #       commit 272e21ccc9a92feeee80aff0587410a314b73c5b
  #
  # 5.4:  https://android-review.googlesource.com/c/kernel/common/+/1355422
  #       commit 45217b91eaaa3a563247c4f470f4cb785de6b1c6
  #
  def testKtimeGetNsApache2(self):
    """ktime_get_ns is callable from a program loaded as "Apache 2.0"."""
    instructions = [BpfFuncCall(BPF_FUNC_ktime_get_ns)] + INS_BPF_EXIT_BLOCK
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SCHED_CLS, instructions,
                               b"Apache 2.0")
    # No exceptions? Good.

  ##############################################################################
  #
  # Test for presence of kernel patch:
  #
  #   BACKPORT: bpf: add bpf_ktime_get_boot_ns()
  #
  # 4.14: https://android-review.googlesource.com/c/kernel/common/+/1585587
  #       commit 34073d7a8ee47ca908b56e9a1d14ca0615fdfc09
  #
  # 4.19: https://android-review.googlesource.com/c/kernel/common/+/1585606
  #       commit 4812ec50935dfe59ba9f48a572e278dd0b02af68
  #
  # 5.4:  https://android-review.googlesource.com/c/kernel/common/+/1585252
  #       commit 57b3f4830fb66a6038c4c1c66ca2e138fe8be231
  #
  def testKtimeGetBootNs(self):
    """A SCHED_CLS program calling ktime_get_boot_ns must load."""
    instructions = [
        BpfFuncCall(BPF_FUNC_ktime_get_boot_ns),
    ] + INS_BPF_EXIT_BLOCK
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SCHED_CLS, instructions,
                               b"Apache 2.0")
    # No exceptions? Good.

  ##############################################################################
  #
  # Test for presence of upstream 5.14 kernel patches:
  #
  # Android12-5.10:
  #   UPSTREAM: net: initialize net->net_cookie at netns setup
  #   https://android-review.git.corp.google.com/c/kernel/common/+/2503195
  #
  #   UPSTREAM: net: retrieve netns cookie via getsocketopt
  #   https://android-review.git.corp.google.com/c/kernel/common/+/2503056
  #
  # (and potentially if you care about kernel ABI)
  #
  #   ANDROID: fix ABI by undoing atomic64_t -> u64 type conversion
  #   https://android-review.git.corp.google.com/c/kernel/common/+/2504335
  #
  # Android13-5.10:
  #   UPSTREAM: net: initialize net->net_cookie at netns setup
  #   https://android-review.git.corp.google.com/c/kernel/common/+/2503795
  #
  #   UPSTREAM: net: retrieve netns cookie via getsocketopt
  #   https://android-review.git.corp.google.com/c/kernel/common/+/2503796
  #
  # (and potentially if you care about kernel ABI)
  #
  #   ANDROID: fix ABI by undoing atomic64_t -> u64 type conversion
  #   https://android-review.git.corp.google.com/c/kernel/common/+/2506895
  #
  @unittest.skipUnless(bpf.HAVE_SO_NETNS_COOKIE, "no SO_NETNS_COOKIE support")
  def testGetNetNsCookie(self):
    """SO_NETNS_COOKIE can be read from a socket and is 8 bytes long."""
    # The with-block closes the socket even if getsockopt raises.
    with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM, 0) as sk:
      cookie = sk.getsockopt(socket.SOL_SOCKET, bpf.SO_NETNS_COOKIE, 8)  # sizeof(u64) == 8
    self.assertEqual(len(cookie), 8)
    cookie = int.from_bytes(cookie, "little")
    self.assertGreaterEqual(cookie, 0)

  def testGetSocketCookie(self):
    """A socket filter counts packets in a map keyed by socket cookie."""
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    # Move skb to REG6 for further usage, call helper function to get socket
    # cookie of current skb and return the cookie at REG0 for next code block
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
        BpfFuncCall(BPF_FUNC_get_socket_cookie)
    ]
    instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
                     + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE
                     + INS_SK_FILTER_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
    packet_count = 10
    def PacketCountByCookie(version):
      # Keep the socket in self.sock while it is open so that tearDown closes
      # it if an assertion below fails.
      self.sock = SocketUDPLoopBack(packet_count, version, self.prog_fd)
      cookie = sock_diag.SockDiag.GetSocketCookie(self.sock)
      self.assertEqual(packet_count, LookupMap(self.map_fd, cookie).value)
      self.sock.close()
      # Do not leave a closed socket behind for tearDown to close again.
      self.sock = None
    PacketCountByCookie(4)
    PacketCountByCookie(6)

  def testGetSocketUid(self):
    """A socket filter counts packets in a map keyed by socket uid."""
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    # Set up the instruction with uid at BPF_REG_0.
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
        BpfFuncCall(BPF_FUNC_get_socket_uid)
    ]
    # Concatenate the generic packet count bpf program to it.
    instructions += (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
                     + INS_SK_FILTER_ACCEPT + INS_PACK_COUNT_UPDATE
                     + INS_SK_FILTER_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
    packet_count = 10
    uid = TEST_UID
    with net_test.RunAsUid(uid):
      self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 4, self.prog_fd).close()
      self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value)
      # Reset the counter so the IPv6 run is checked independently.
      DeleteMap(self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 6, self.prog_fd).close()
      self.assertEqual(packet_count, LookupMap(self.map_fd, uid).value)
490
491
class BpfCgroupTest(net_test.NetworkTest):
  """Tests for BPF programs attached to the cgroup at /sys/fs/cgroup.

  The cgroup directory is opened once per class. Every test gets fresh
  prog_fd / map_fd slots, and tearDown detaches whatever may still be attached
  so that a failing test cannot affect the ones that follow.
  """

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    open_flags = os.O_DIRECTORY | os.O_RDONLY
    cls._cg_fd = os.open("/sys/fs/cgroup", open_flags)

  @classmethod
  def tearDownClass(cls):
    os.close(cls._cg_fd)
    super().tearDownClass()

  def setUp(self):
    super().setUp()
    # -1 means "no fd to close" in tearDown.
    self.prog_fd = -1
    self.map_fd = -1

  def tearDown(self):
    if self.prog_fd >= 0:
      os.close(self.prog_fd)
    if self.map_fd >= 0:
      os.close(self.map_fd)
    # Best-effort cleanup: a detach fails when the test already detached the
    # program (or never attached one), and that is fine.
    for attach_type in (BPF_CGROUP_INET_EGRESS,
                        BPF_CGROUP_INET_INGRESS,
                        BPF_CGROUP_INET_SOCK_CREATE):
      try:
        BpfProgDetach(self._cg_fd, attach_type)
      except socket.error:
        pass
    super().tearDown()

  def testCgroupBpfAttach(self):
    """A CGROUP_SKB program can be attached to and detached from ingress."""
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)

  def testCgroupIngress(self):
    """A blocking ingress program yields EAGAIN until it is detached."""
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
    for version in (4, 6):
      self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, version, None)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)
    # With the program gone, loopback traffic flows again.
    for version in (4, 6):
      SocketUDPLoopBack(1, version, None).close()

  def testCgroupEgress(self):
    """A blocking egress program yields EPERM until it is detached."""
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_EGRESS)
    for version in (4, 6):
      self.assertRaisesErrno(errno.EPERM, SocketUDPLoopBack, 1, version, None)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_EGRESS)
    # With the program gone, loopback traffic flows again.
    for version in (4, 6):
      SocketUDPLoopBack(1, version, None).close()

  def testCgroupBpfUid(self):
    """An ingress cgroup program counts packets in a map keyed by uid."""
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    # Same shape as the program in testGetSocketUid, except that the accept
    # blocks are the cgroup ones.
    counting_prog = (
        [
            BpfMov64Reg(BPF_REG_6, BPF_REG_1),
            BpfFuncCall(BPF_FUNC_get_socket_uid)
        ]
        + INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(self.map_fd)
        + INS_CGROUP_ACCEPT + INS_PACK_COUNT_UPDATE
        + INS_CGROUP_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, counting_prog)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
    expected = 20
    test_uid = TEST_UID
    with net_test.RunAsUid(test_uid):
      # Nothing has been counted for this uid yet.
      self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, test_uid)
      SocketUDPLoopBack(expected, 4, None).close()
      self.assertEqual(expected, LookupMap(self.map_fd, test_uid).value)
      # Drop the entry so the IPv6 run starts counting from scratch.
      DeleteMap(self.map_fd, test_uid)
      SocketUDPLoopBack(expected, 6, None).close()
      self.assertEqual(expected, LookupMap(self.map_fd, test_uid).value)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)

  def checkSocketCreate(self, family, socktype, success):
    """Creates one socket and fails the test if the outcome is unexpected.

    Args:
      family: address family passed to socket.socket.
      socktype: socket type passed to socket.socket.
      success: True if creation is expected to work, False if it is expected
        to raise socket.error.
    """
    try:
      socket.socket(family, socktype, 0).close()
    except socket.error as err:
      if success:
        self.fail("Failed to create socket family=%d type=%d err=%s" %
                  (family, socktype, os.strerror(err.errno)))
    else:
      if not success:
        self.fail("unexpected socket family=%d type=%d created, should be blocked"
                  % (family, socktype))

  def trySocketCreate(self, success):
    """Runs checkSocketCreate for each IPv4/IPv6 x UDP/TCP combination."""
    families = (socket.AF_INET, socket.AF_INET6)
    socktypes = (socket.SOCK_DGRAM, socket.SOCK_STREAM)
    for fam in families:
      for kind in socktypes:
        self.checkSocketCreate(fam, kind, success)

  def testCgroupSocketCreateBlock(self):
    """A SOCK_CREATE program blocks socket creation for TEST_UID only."""
    # Get the current uid/gid, mask the result with 0xfffffff and compare it
    # with TEST_UID: on mismatch jump over the two-instruction block sequence
    # to the accept sequence, otherwise fall through and block.
    uid_filter = [
        BpfFuncCall(BPF_FUNC_get_current_uid_gid),
        BpfAlu64Imm(BPF_AND, BPF_REG_0, 0xfffffff),
        BpfJumpImm(BPF_JNE, BPF_REG_0, TEST_UID, 2),
    ] + INS_BPF_EXIT_BLOCK + INS_CGROUP_ACCEPT
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SOCK, uid_filter)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE)
    with net_test.RunAsUid(TEST_UID):
      # Socket creation with the target uid should fail.
      self.trySocketCreate(False)
    # Socket creation with a different uid should succeed.
    self.trySocketCreate(True)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE)
    # Once detached, the target uid is no longer blocked either.
    with net_test.RunAsUid(TEST_UID):
      self.trySocketCreate(True)
610
# Run the tests in this module when it is executed as a script.
if __name__ == "__main__":
  unittest.main()
613