#!/usr/bin/python
#
# Copyright 2016 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ctypes
import ctypes.util  # "import ctypes" alone does not load the util submodule.
import errno
import os
import socket
import struct
import subprocess
import tempfile
import unittest

from bpf import *  # pylint: disable=wildcard-import
import csocket
import net_test
import sock_diag

libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
HAVE_EBPF_ACCOUNTING = net_test.LINUX_VERSION >= (4, 9, 0)
HAVE_EBPF_SOCKET = net_test.LINUX_VERSION >= (4, 14, 0)
KEY_SIZE = 8
VALUE_SIZE = 4
TOTAL_ENTRIES = 20
TEST_UID = 54321
TEST_GID = 12345
# Stack offset, relative to the frame pointer BPF_REG_10, of the slot that
# holds the map key.
key_offset = -8
# Stack offset, relative to the frame pointer BPF_REG_10, of the slot that
# holds the map value.
value_offset = -16


# Debug usage only.
def PrintMapInfo(map_fd):
  """Prints the <key, value> pairs of the BPF map referred to by map_fd.

  Iteration starts after key 10086, which is assumed not to be in the map.
  If it is in the map, that key and the keys preceding it in the map's
  iteration order are not printed.

  Args:
    map_fd: file descriptor of the BPF map to dump.
  """
  # A random key that the map is assumed not to contain.
  key = 10086
  while True:
    try:
      nextKey = GetNextKey(map_fd, key).value
      value = LookupMap(map_fd, nextKey)
    except EnvironmentError:
      # GetNextKey fails (ENOENT) once the last key has been visited. Catch
      # only OS/socket errors so that real bugs and Ctrl-C still propagate.
      print("no value")
      break
    print(repr(nextKey) + " : " + repr(value.value))
    key = nextKey


# A dummy loopback function that causes a socket to send traffic to itself.
def SocketUDPLoopBack(packet_count, version, prog_fd):
  """Makes a UDP socket send packet_count datagrams to itself over loopback.

  Args:
    packet_count: number of "foo" datagrams to send. Each one is received
      back before the next one is sent.
    version: IP version, 4 or 6.
    prog_fd: fd of a BPF program to attach to the socket with
      BpfProgAttachSocket, or None to attach nothing.

  Returns:
    The bound socket. The caller is responsible for closing it.

  Raises:
    AssertionError: a received datagram differs from what was sent.
    Any error raised by the socket operations, e.g. EAGAIN or EPERM when a
    BPF program blocks the traffic. The socket is closed before such an
    error propagates.
  """
  family = {4: socket.AF_INET, 6: socket.AF_INET6}[version]
  sock = socket.socket(family, socket.SOCK_DGRAM, 0)
  try:
    if prog_fd is not None:
      BpfProgAttachSocket(sock.fileno(), prog_fd)
    net_test.SetNonBlocking(sock)
    addr = {4: "127.0.0.1", 6: "::1"}[version]
    sock.bind((addr, 0))
    addr = sock.getsockname()
    sockaddr = csocket.Sockaddr(addr)
    for _ in xrange(packet_count):
      sock.sendto("foo", addr)
      data, retaddr = csocket.Recvfrom(sock, 4096, 0)
      # Explicit checks instead of assert statements, which "python -O"
      # strips. "not a == b" keeps the original comparison operator.
      if not "foo" == data:
        raise AssertionError("unexpected data %r" % (data,))
      if not sockaddr == retaddr:
        raise AssertionError("unexpected source address %r" % (retaddr,))
  except BaseException:
    # Many callers expect this function to fail (assertRaisesErrno), so do
    # not leak the socket on the error path.
    sock.close()
    raise
  return sock


# The main code block for the eBPF packet counting program. It looks up, in
# the given BPF map, the key that earlier code stored on the stack at
# key_offset (INS_BPF_PARAM_STORE stores BPF_REG_0 there). If the element does
# not exist in the map yet, the program inserts a new <key, 1> pair. Otherwise
# it jumps ahead and lets a later code block handle it.
#
# Register usage:
# REG0: holds the return value of helper functions and the final return value
#   of the eBPF program.
# REG1 - REG5: temporary registers used to hold values and to pass arguments
#   to eBPF helper functions. Their values are not preserved across a helper
#   call.
# REG6 - REG9: registers whose values are preserved across eBPF helper calls.
# REG10: frame pointer of the program's stack. A value stored on the stack is
#   addressed by its (negative) offset from REG10, e.g. key_offset or
#   value_offset.
def BpfFuncCountPacketInit(map_fd):
  """Returns instructions that insert <key, 1> into map_fd if key is absent.

  Precondition: the key has been stored on the stack at key_offset from
  BPF_REG_10 (INS_BPF_PARAM_STORE does this with the value in BPF_REG_0).

  If the lookup finds an existing element, BPF_REG_0 holds the result of
  map_lookup_elem and execution jumps forward 10 instruction slots, past the
  insert code. NOTE(review): the offset appears to count BpfLoadMapFd as two
  slots. It also assumes the caller appends a two-instruction accept block
  (INS_SK_FILTER_ACCEPT or INS_CGROUP_ACCEPT) right after this code, so that
  the jump lands on INS_PACK_COUNT_UPDATE. Keep these in sync.

  Args:
    map_fd: fd of the BPF map that holds the packet counts.

  Returns:
    A list of BPF instructions.
  """
  key_pos = BPF_REG_7
  insPackCountStart = [
      # Point BPF_REG_7 at the key stored on the stack at key_offset.
      BpfMov64Reg(key_pos, BPF_REG_10),
      BpfAlu64Imm(BPF_ADD, key_pos, key_offset),
      # Load the map fd and look up the key in the map.
      BpfLoadMapFd(map_fd, BPF_REG_1),
      BpfMov64Reg(BPF_REG_2, key_pos),
      BpfFuncCall(BPF_FUNC_map_lookup_elem),
      # If the map element already exists (BPF_REG_0 != 0), jump out of this
      # code block and let the next part handle it.
      BpfJumpImm(BPF_JNE, BPF_REG_0, 0, 10),
      BpfLoadMapFd(map_fd, BPF_REG_1),
      BpfMov64Reg(BPF_REG_2, key_pos),
      # Initialize a new <key, value> pair with value 1 and write it to the
      # map.
      BpfStMem(BPF_W, BPF_REG_10, value_offset, 1),
      BpfMov64Reg(BPF_REG_3, BPF_REG_10),
      BpfAlu64Imm(BPF_ADD, BPF_REG_3, value_offset),
      # Fourth helper argument (update flags) is 0.
      BpfMov64Imm(BPF_REG_4, 0),
      BpfFuncCall(BPF_FUNC_map_update_elem)
  ]
  return insPackCountStart


# Bpf instructions that set the return value to 0 and exit. Used in this file
# to drop packets and to reject socket creation.
INS_BPF_EXIT_BLOCK = [
    BpfMov64Imm(BPF_REG_0, 0),
    BpfExitInsn()
]

# Bpf instructions for a cgroup bpf filter to accept a packet and exit.
INS_CGROUP_ACCEPT = [
    # Set return value to 1 and exit.
    BpfMov64Imm(BPF_REG_0, 1),
    BpfExitInsn()
]

# Bpf instructions for a socket bpf filter to accept a packet and exit.
INS_SK_FILTER_ACCEPT = [
    # Precondition: BPF_REG_6 = sk_buff context
    # Load the packet length from BPF_REG_6 and store it in BPF_REG_0 as the
    # return value.
    BpfLdxMem(BPF_W, BPF_REG_0, BPF_REG_6, 0),
    BpfExitInsn()
]

# Update an existing map element with +1.
INS_PACK_COUNT_UPDATE = [
    # Precondition: BPF_REG_0 = Value retrieved from BPF maps
    # Add one to the 32-bit eBPF value that BPF_REG_0 points to, using
    # BPF_XADD with the address in BPF_REG_2 and the addend (1) in BPF_REG_1.
    BpfMov64Reg(BPF_REG_2, BPF_REG_0),
    BpfMov64Imm(BPF_REG_1, 1),
    BpfRawInsn(BPF_STX | BPF_XADD | BPF_W, BPF_REG_2, BPF_REG_1, 0, 0),
]

# Store the 64-bit key held in BPF_REG_0 on the stack at key_offset, where
# BpfFuncCountPacketInit reads it.
INS_BPF_PARAM_STORE = [
    BpfStxMem(BPF_DW, BPF_REG_10, BPF_REG_0, key_offset),
]


def _PacketCountInstructions(map_fd, ins_accept):
  """Returns the generic packet-counting program that follows the preamble.

  The preamble must leave the map key in BPF_REG_0. If ins_accept is
  INS_SK_FILTER_ACCEPT, it must also leave the sk_buff context in BPF_REG_6.

  Args:
    map_fd: fd of the BPF map that holds the packet counts.
    ins_accept: two-instruction block that accepts the packet and exits
      (INS_SK_FILTER_ACCEPT or INS_CGROUP_ACCEPT). It must be two
      instructions long because BpfFuncCountPacketInit's jump offset counts
      it.

  Returns:
    A list of BPF instructions.
  """
  return (INS_BPF_PARAM_STORE + BpfFuncCountPacketInit(map_fd) + ins_accept
          + INS_PACK_COUNT_UPDATE + ins_accept)


@unittest.skipUnless(HAVE_EBPF_ACCOUNTING,
                     "BPF helper function is not fully supported")
class BpfTest(net_test.NetworkTest):
  """Tests for BPF maps and socket-filter BPF programs."""

  def setUp(self):
    self.map_fd = -1
    self.prog_fd = -1
    self.sock = None

  def tearDown(self):
    if self.prog_fd >= 0:
      os.close(self.prog_fd)
    if self.map_fd >= 0:
      os.close(self.map_fd)
    if self.sock is not None:
      self.sock.close()

  def testCreateMap(self):
    key, value = 1, 1
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    UpdateMap(self.map_fd, key, value)
    self.assertEquals(value, LookupMap(self.map_fd, key).value)
    DeleteMap(self.map_fd, key)
    self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, key)

  def CheckAllMapEntry(self, nonexistent_key, totalEntries, value):
    """Checks the map iteration that starts after nonexistent_key.

    Exactly totalEntries keys must follow nonexistent_key, each mapping to
    value. After that, GetNextKey must fail with ENOENT.
    """
    key = nonexistent_key
    for _ in xrange(totalEntries):
      key = GetNextKey(self.map_fd, key).value
      self.assertGreaterEqual(key, 0)
      self.assertEquals(value, LookupMap(self.map_fd, key).value)
    self.assertRaisesErrno(errno.ENOENT, GetNextKey, self.map_fd, key)

  def testIterateMap(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    value = 1024
    for key in xrange(0, TOTAL_ENTRIES):
      UpdateMap(self.map_fd, key, value)
    for key in xrange(0, TOTAL_ENTRIES):
      self.assertEquals(value, LookupMap(self.map_fd, key).value)
    self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, 101)
    nonexistent_key = -1
    self.CheckAllMapEntry(nonexistent_key, TOTAL_ENTRIES, value)

  def testFindFirstMapKey(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    value = 1024
    for key in xrange(0, TOTAL_ENTRIES):
      UpdateMap(self.map_fd, key, value)
    firstKey = GetFirstKey(self.map_fd)
    key = firstKey.value
    self.CheckAllMapEntry(key, TOTAL_ENTRIES - 1, value)

  def testRdOnlyMap(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES, map_flags=BPF_F_RDONLY)
    value = 1024
    key = 1
    self.assertRaisesErrno(errno.EPERM, UpdateMap, self.map_fd, key, value)
    self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, key)

  def testWrOnlyMap(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES, map_flags=BPF_F_WRONLY)
    value = 1024
    key = 1
    UpdateMap(self.map_fd, key, value)
    self.assertRaisesErrno(errno.EPERM, LookupMap, self.map_fd, key)

  def testProgLoad(self):
    # Move skb to BPF_REG_6 for further usage
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1)
    ]
    instructions += INS_SK_FILTER_ACCEPT
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
    SocketUDPLoopBack(1, 4, self.prog_fd)
    SocketUDPLoopBack(1, 6, self.prog_fd)

  def testPacketBlock(self):
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, INS_BPF_EXIT_BLOCK)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 4, self.prog_fd)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 6, self.prog_fd)

  def testPacketCount(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    key = 0xf0f0
    # Set up instruction block with key loaded at BPF_REG_0.
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
        BpfMov64Imm(BPF_REG_0, key)
    ]
    # Concatenate the generic packet count bpf program to it.
    instructions += _PacketCountInstructions(self.map_fd, INS_SK_FILTER_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
    packet_count = 10
    SocketUDPLoopBack(packet_count, 4, self.prog_fd)
    SocketUDPLoopBack(packet_count, 6, self.prog_fd)
    self.assertEquals(packet_count * 2, LookupMap(self.map_fd, key).value)

  def testGetSocketCookie(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    # Move skb to REG6 for further usage, call helper function to get socket
    # cookie of current skb and return the cookie at REG0 for next code block
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
        BpfFuncCall(BPF_FUNC_get_socket_cookie)
    ]
    instructions += _PacketCountInstructions(self.map_fd, INS_SK_FILTER_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
    packet_count = 10

    def PacketCountByCookie(version):
      self.sock = SocketUDPLoopBack(packet_count, version, self.prog_fd)
      cookie = sock_diag.SockDiag.GetSocketCookie(self.sock)
      self.assertEquals(packet_count, LookupMap(self.map_fd, cookie).value)
      self.sock.close()

    PacketCountByCookie(4)
    PacketCountByCookie(6)

  def testGetSocketUid(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    # Set up the instruction with uid at BPF_REG_0.
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
        BpfFuncCall(BPF_FUNC_get_socket_uid)
    ]
    # Concatenate the generic packet count bpf program to it.
    instructions += _PacketCountInstructions(self.map_fd, INS_SK_FILTER_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_SOCKET_FILTER, instructions)
    packet_count = 10
    uid = TEST_UID
    with net_test.RunAsUid(uid):
      self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 4, self.prog_fd)
      self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value)
      DeleteMap(self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 6, self.prog_fd)
      self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value)


@unittest.skipUnless(HAVE_EBPF_ACCOUNTING,
                     "Cgroup BPF is not fully supported")
class BpfCgroupTest(net_test.NetworkTest):
  """Tests for BPF programs attached to a temporary cgroup2 mount."""

  @classmethod
  def setUpClass(cls):
    cls._cg_dir = tempfile.mkdtemp(prefix="cg_bpf-")
    try:
      # Pass an argument list so a temp path with spaces cannot be mis-split.
      subprocess.check_call(["mount", "-t", "cgroup2", "cg_bpf", cls._cg_dir])
    except (subprocess.CalledProcessError, OSError):
      # If an exception is thrown in setUpClass, the test fails and
      # tearDownClass is not called. OSError covers a missing mount binary.
      os.rmdir(cls._cg_dir)
      raise
    cls._cg_fd = os.open(cls._cg_dir, os.O_DIRECTORY | os.O_RDONLY)

  @classmethod
  def tearDownClass(cls):
    os.close(cls._cg_fd)
    subprocess.call(["umount", cls._cg_dir])
    os.rmdir(cls._cg_dir)

  def setUp(self):
    self.prog_fd = -1
    self.map_fd = -1

  def tearDown(self):
    if self.prog_fd >= 0:
      os.close(self.prog_fd)
    if self.map_fd >= 0:
      os.close(self.map_fd)
    # Best effort: a test may or may not have left a program attached.
    for attach_type in (BPF_CGROUP_INET_EGRESS, BPF_CGROUP_INET_INGRESS,
                        BPF_CGROUP_INET_SOCK_CREATE):
      try:
        BpfProgDetach(self._cg_fd, attach_type)
      except socket.error:
        pass

  def testCgroupBpfAttach(self):
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)

  def testCgroupIngress(self):
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 4, None)
    self.assertRaisesErrno(errno.EAGAIN, SocketUDPLoopBack, 1, 6, None)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)
    SocketUDPLoopBack(1, 4, None)
    SocketUDPLoopBack(1, 6, None)

  def testCgroupEgress(self):
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, INS_BPF_EXIT_BLOCK)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_EGRESS)
    self.assertRaisesErrno(errno.EPERM, SocketUDPLoopBack, 1, 4, None)
    self.assertRaisesErrno(errno.EPERM, SocketUDPLoopBack, 1, 6, None)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_EGRESS)
    SocketUDPLoopBack(1, 4, None)
    SocketUDPLoopBack(1, 6, None)

  def testCgroupBpfUid(self):
    self.map_fd = CreateMap(BPF_MAP_TYPE_HASH, KEY_SIZE, VALUE_SIZE,
                            TOTAL_ENTRIES)
    # Similar to the program used in testGetSocketUid.
    instructions = [
        BpfMov64Reg(BPF_REG_6, BPF_REG_1),
        BpfFuncCall(BPF_FUNC_get_socket_uid)
    ]
    instructions += _PacketCountInstructions(self.map_fd, INS_CGROUP_ACCEPT)
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SKB, instructions)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_INGRESS)
    packet_count = 20
    uid = TEST_UID
    with net_test.RunAsUid(uid):
      self.assertRaisesErrno(errno.ENOENT, LookupMap, self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 4, None)
      self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value)
      DeleteMap(self.map_fd, uid)
      SocketUDPLoopBack(packet_count, 6, None)
      self.assertEquals(packet_count, LookupMap(self.map_fd, uid).value)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_INGRESS)

  def checkSocketCreate(self, family, socktype, success):
    """Fails the test unless creating the socket succeeds exactly if success."""
    try:
      sock = socket.socket(family, socktype, 0)
      sock.close()
    except socket.error as e:
      if success:
        self.fail("Failed to create socket family=%d type=%d err=%s" %
                  (family, socktype, os.strerror(e.errno)))
      return
    if not success:
      self.fail("unexpected socket family=%d type=%d created, should be blocked" %
                (family, socktype))

  def trySocketCreate(self, success):
    """Runs checkSocketCreate for IPv4/IPv6 datagram and stream sockets."""
    for family in [socket.AF_INET, socket.AF_INET6]:
      for socktype in [socket.SOCK_DGRAM, socket.SOCK_STREAM]:
        self.checkSocketCreate(family, socktype, success)

  @unittest.skipUnless(HAVE_EBPF_SOCKET,
                       "Cgroup BPF socket is not supported")
  def testCgroupSocketCreateBlock(self):
    instructions = [
        BpfFuncCall(BPF_FUNC_get_current_uid_gid),
        # NOTE(review): 0xfffffff keeps only the low 28 bits of the helper's
        # result (presumably the uid lives in the low 32 bits). That is enough
        # for TEST_UID. A full 0xffffffff immediate would presumably be
        # sign-extended by this 64-bit ALU op and mask nothing.
        BpfAlu64Imm(BPF_AND, BPF_REG_0, 0xfffffff),
        # If the uid is not TEST_UID, skip the two-instruction exit block and
        # fall into the accept block.
        BpfJumpImm(BPF_JNE, BPF_REG_0, TEST_UID, 2),
    ]
    instructions += INS_BPF_EXIT_BLOCK + INS_CGROUP_ACCEPT
    self.prog_fd = BpfProgLoad(BPF_PROG_TYPE_CGROUP_SOCK, instructions)
    BpfProgAttach(self.prog_fd, self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE)
    with net_test.RunAsUid(TEST_UID):
      # Socket creation with target uid should fail
      self.trySocketCreate(False)
    # Socket creation with a different uid should succeed
    self.trySocketCreate(True)
    BpfProgDetach(self._cg_fd, BPF_CGROUP_INET_SOCK_CREATE)
    with net_test.RunAsUid(TEST_UID):
      self.trySocketCreate(True)


if __name__ == "__main__":
  unittest.main()