#!/usr/bin/python
#
# Copyright 2015 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Partial Python implementation of sock_diag functionality."""

# pylint: disable=g-bad-todo

import errno
import os
from socket import *  # pylint: disable=wildcard-import
import struct

import csocket
import cstruct
import net_test
import netlink

### sock_diag constants. See include/uapi/linux/sock_diag.h.
# Message types.
SOCK_DIAG_BY_FAMILY = 20
SOCK_DESTROY = 21

### inet_diag_constants. See include/uapi/linux/inet_diag.h
# Message types.
TCPDIAG_GETSOCK = 18

# Request attributes.
INET_DIAG_REQ_BYTECODE = 1

# Extensions. These are attribute types in dump replies; a request asks for
# extension X by setting bit (X - 1) in its "ext" field.
INET_DIAG_NONE = 0
INET_DIAG_MEMINFO = 1
INET_DIAG_INFO = 2
INET_DIAG_VEGASINFO = 3
INET_DIAG_CONG = 4
INET_DIAG_TOS = 5
INET_DIAG_TCLASS = 6
INET_DIAG_SKMEMINFO = 7
INET_DIAG_SHUTDOWN = 8
INET_DIAG_DCTCPINFO = 9
INET_DIAG_PROTOCOL = 10
INET_DIAG_SKV6ONLY = 11
INET_DIAG_LOCALS = 12
INET_DIAG_PEERS = 13
INET_DIAG_PAD = 14
INET_DIAG_MARK = 15

# Bytecode operations.
INET_DIAG_BC_NOP = 0
INET_DIAG_BC_JMP = 1
INET_DIAG_BC_S_GE = 2
INET_DIAG_BC_S_LE = 3
INET_DIAG_BC_D_GE = 4
INET_DIAG_BC_D_LE = 5
INET_DIAG_BC_AUTO = 6
INET_DIAG_BC_S_COND = 7
INET_DIAG_BC_D_COND = 8
INET_DIAG_BC_DEV_COND = 9
INET_DIAG_BC_MARK_COND = 10

# Data structure formats.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
# NOTE(review): "!" makes every field big-endian, including iface; confirm
# this matches the kernel's idiag_if byte order on the target architecture.
InetDiagSockId = cstruct.Struct(
    "InetDiagSockId", "!HH16s16sI8s", "sport dport src dst iface cookie")
InetDiagReqV2 = cstruct.Struct(
    "InetDiagReqV2", "=BBBxIS", "family protocol ext states id",
    [InetDiagSockId])
InetDiagMsg = cstruct.Struct(
    "InetDiagMsg", "=BBBBSLLLLL",
    "family state timer retrans id expires rqueue wqueue uid inode",
    [InetDiagSockId])
InetDiagMeminfo = cstruct.Struct(
    "InetDiagMeminfo", "=IIII", "rmem wmem fmem tmem")
InetDiagBcOp = cstruct.Struct("InetDiagBcOp", "BBH", "code yes no")
InetDiagHostcond = cstruct.Struct("InetDiagHostcond", "=BBxxi",
                                  "family prefix_len port")
InetDiagMarkcond = cstruct.Struct("InetDiagMarkcond", "=II", "mark mask")

SkMeminfo = cstruct.Struct(
    "SkMeminfo", "=IIIIIIII",
    "rmem_alloc rcvbuf wmem_alloc sndbuf fwd_alloc wmem_queued optmem backlog")
TcpInfo = cstruct.Struct(
    "TcpInfo", "=BBBBBBBxIIIIIIIIIIIIIIIIIIIIIIII",
    "state ca_state retransmits probes backoff options wscale "
    "rto ato snd_mss rcv_mss "
    "unacked sacked lost retrans fackets "
    "last_data_sent last_ack_sent last_data_recv last_ack_recv "
    "pmtu rcv_ssthresh rtt rttvar snd_ssthresh snd_cwnd advmss reordering "
    "rcv_rtt rcv_space "
    "total_retrans")  # As of linux 3.13, at least.

TCP_TIME_WAIT = 6
# Bitmask of socket states for InetDiagReqV2.states: every state except
# TIME_WAIT.
ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << TCP_TIME_WAIT)


class SockDiag(netlink.NetlinkSocket):
  """A NETLINK_SOCK_DIAG socket for dumping, inspecting and closing sockets."""

  # Debug categories; MaybeDebugCommand prints requests if "ALL" or "SOCK"
  # is present.
  NL_DEBUG = []

  def __init__(self):
    super(SockDiag, self).__init__(netlink.NETLINK_SOCK_DIAG)

  def _Decode(self, command, msg, nla_type, nla_data):
    """Decodes netlink attributes to Python types.

    Args:
      command: the netlink command the attribute belongs to (unused here).
      msg: the parsed fixed-size message header, e.g. an InetDiagReqV2 or an
        InetDiagMsg.
      nla_type: the integer attribute type.
      nla_data: the raw attribute payload.

    Returns:
      A (name, data) tuple. For AF_INET/AF_INET6 messages, name is the
      constant name looked up with the INET_DIAG_REQ (requests) or INET_DIAG
      (replies) prefix; otherwise it is the raw integer type. data is the
      decoded value, or the raw payload if the type is not recognized.
    """
    if msg.family in (AF_INET, AF_INET6):
      if isinstance(msg, InetDiagReqV2):
        prefix = "INET_DIAG_REQ"
      else:
        prefix = "INET_DIAG"
      name = self._GetConstantName(__name__, nla_type, prefix)
    else:
      # Don't know what this is. Leave it as an integer.
      name = nla_type

    if name in ["INET_DIAG_SHUTDOWN", "INET_DIAG_TOS", "INET_DIAG_TCLASS",
                "INET_DIAG_SKV6ONLY"]:
      data = ord(nla_data)
    elif name == "INET_DIAG_CONG":
      data = nla_data.strip("\x00")
    elif name == "INET_DIAG_MEMINFO":
      data = InetDiagMeminfo(nla_data)
    elif name == "INET_DIAG_INFO":
      # TODO: Catch the exception and try something else if it's not TCP.
      data = TcpInfo(nla_data)
    elif name == "INET_DIAG_SKMEMINFO":
      data = SkMeminfo(nla_data)
    elif name == "INET_DIAG_MARK":
      data = struct.unpack("=I", nla_data)[0]
    elif name == "INET_DIAG_REQ_BYTECODE":
      data = self.DecodeBytecode(nla_data)
    elif name in ["INET_DIAG_LOCALS", "INET_DIAG_PEERS"]:
      data = []
      while nla_data:
        # The SCTP diag code always appears to copy sizeof(sockaddr_storage)
        # bytes, but does so from a union sctp_addr which is at most as long
        # as a sockaddr_in6.
        addr, nla_data = cstruct.Read(nla_data, csocket.SockaddrStorage)
        if addr.family == AF_INET:
          addr = csocket.SockaddrIn(addr.Pack())
        elif addr.family == AF_INET6:
          addr = csocket.SockaddrIn6(addr.Pack())
        data.append(addr)
    else:
      data = nla_data

    return name, data

  def MaybeDebugCommand(self, command, unused_flags, data):
    """Prints the request if "ALL" or "SOCK" debugging is enabled."""
    if "ALL" not in self.NL_DEBUG and "SOCK" not in self.NL_DEBUG:
      return
    # Only resolve the command name when we are actually going to print it.
    name = self._GetConstantName(__name__, command, "SOCK_")
    parsed = self._ParseNLMsg(data, InetDiagReqV2)
    print("%s %s" % (name, str(parsed)))

  @staticmethod
  def _EmptyInetDiagSockId():
    """Returns an InetDiagSockId with every field zeroed."""
    return InetDiagSockId(("\x00" * len(InetDiagSockId)))

  @staticmethod
  def PackBytecode(instructions):
    """Compiles instructions to inet_diag bytecode.

    The input is a list of (INET_DIAG_BC_xxx, yes, no, arg) tuples, where yes
    and no are relative jump offsets measured in instructions. The yes branch
    is taken if the instruction matches.

    To accept, jump 1 past the last instruction. To reject, jump 2 past the
    last instruction.

    The target of a no jump is only valid if it is reachable by following
    only yes jumps from the first instruction - see inet_diag_bc_audit and
    valid_cc. This means that if cond1 and cond2 are two mutually exclusive
    filter terms, it is not possible to implement cond1 OR cond2 using:

      ...
      cond1 2 1 arg
      cond2 1 2 arg
      accept
      reject

    but only using:

      ...
      cond1 1 2 arg
      jmp   1 2
      cond2 1 2 arg
      accept
      reject

    The jmp instruction ignores yes and always jumps to no, but yes must be 1
    or the bytecode won't validate. It doesn't have to be jmp - any instruction
    that is guaranteed not to match on real data will do.

    Args:
      instructions: list of instruction tuples. The arg element is ignored for
        NOP/JMP/AUTO, a port number for S_GE/S_LE/D_GE/D_LE, an
        (address, prefix_len, port) tuple for S_COND/D_COND, and either a
        mark or a (mark, mask) tuple for MARK_COND.

    Returns:
      A string, the raw bytecode.

    Raises:
      ValueError: if a jump offset is not positive or an opcode is
        unsupported.
    """
    args = []
    # positions[i] is the byte offset of instruction i. Two extra entries are
    # appended: the accept label (end of bytecode) and the reject label.
    positions = [0]

    for op, yes, no, arg in instructions:

      if yes <= 0 or no <= 0:
        raise ValueError("Jumps must be > 0")

      if op in [INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO]:
        arg = ""
      elif op in [INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
                  INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE]:
        # The port travels in the "no" field of a second InetDiagBcOp, so
        # the code and yes bytes are zero.
        arg = "\x00\x00" + struct.pack("=H", arg)
      elif op in [INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND]:
        addr, prefixlen, port = arg
        family = AF_INET6 if ":" in addr else AF_INET
        addr = inet_pton(family, addr)
        arg = InetDiagHostcond((family, prefixlen, port)).Pack() + addr
      elif op == INET_DIAG_BC_MARK_COND:
        if isinstance(arg, tuple):
          mark, mask = arg
        else:
          mark, mask = arg, 0xffffffff
        arg = InetDiagMarkcond((mark, mask)).Pack()
      else:
        raise ValueError("Unsupported opcode %d" % op)

      args.append(arg)
      length = len(InetDiagBcOp) + len(arg)
      positions.append(positions[-1] + length)

    # Reject label.
    positions.append(positions[-1] + 4)  # Why 4? Because the kernel uses 4.
    assert len(args) == len(instructions) == len(positions) - 2

    # Second pass: convert instruction-relative jumps into byte offsets.
    packed = []
    for i, (op, yes, no, _) in enumerate(instructions):
      yes = positions[i + yes] - positions[i]
      no = positions[i + no] - positions[i]
      packed.append(InetDiagBcOp((op, yes, no)).Pack() + args[i])

    return "".join(packed)

  @staticmethod
  def DecodeBytecode(bytecode):
    """Decodes raw inet_diag bytecode for debugging.

    Args:
      bytecode: the raw bytecode string, e.g. as produced by PackBytecode.

    Returns:
      A list of (op, arg) tuples, where op is the InetDiagBcOp header of each
      instruction and arg its decoded argument (None if it has none), or the
      string "???" if the bytecode could not be parsed. Decoding is
      best-effort because it is only used to render debug output.
    """
    instructions = []
    try:
      while bytecode:
        op, rest = cstruct.Read(bytecode, InetDiagBcOp)

        if op.code in [INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO]:
          arg = None
        elif op.code in [INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
                         INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE]:
          # The port is in the "no" field of a second InetDiagBcOp. Read it
          # into its own variable so the instruction header in `op` is not
          # overwritten before it is appended below.
          port_op, rest = cstruct.Read(rest, InetDiagBcOp)
          arg = port_op.no
        elif op.code in [INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND]:
          cond, rest = cstruct.Read(rest, InetDiagHostcond)
          if cond.family == 0:
            # No address follows the condition header.
            arg = (None, cond.prefix_len, cond.port)
          else:
            addrlen = 4 if cond.family == AF_INET else 16
            addr, rest = rest[:addrlen], rest[addrlen:]
            addr = inet_ntop(cond.family, addr)
            arg = (addr, cond.prefix_len, cond.port)
        elif op.code == INET_DIAG_BC_DEV_COND:
          attrlen = struct.calcsize("=I")
          attr, rest = rest[:attrlen], rest[attrlen:]
          # struct.unpack returns a tuple; extract the single interface index.
          arg = struct.unpack("=I", attr)[0]
        elif op.code == INET_DIAG_BC_MARK_COND:
          arg, rest = cstruct.Read(rest, InetDiagMarkcond)
        else:
          raise ValueError("Unknown opcode %d" % op.code)
        instructions.append((op, arg))
        bytecode = rest

      return instructions
    except (TypeError, ValueError, struct.error):
      # struct.error covers truncated arguments (it is not a ValueError).
      return "???"

  def Dump(self, diag_req, bytecode):
    """Sends a SOCK_DIAG_BY_FAMILY dump request and returns the replies.

    Args:
      diag_req: an InetDiagReqV2 selecting the sockets to dump.
      bytecode: optional raw filter bytecode; sent as an
        INET_DIAG_REQ_BYTECODE attribute if non-empty.

    Returns:
      Whatever the base class _Dump returns for InetDiagMsg replies; callers
      in this file iterate it as (diag_msg, attrs) pairs.
    """
    if bytecode:
      bytecode = self._NlAttr(INET_DIAG_REQ_BYTECODE, bytecode)

    return self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, bytecode)

  def DumpAllInetSockets(self, protocol, bytecode, sock_id=None, ext=0,
                         states=ALL_NON_TIME_WAIT):
    """Dumps IPv4 and IPv6 sockets matching the specified parameters.

    Args:
      protocol: IP protocol number, e.g. IPPROTO_TCP.
      bytecode: optional raw filter bytecode (see PackBytecode).
      sock_id: InetDiagSockId to match; defaults to an all-zero id.
      ext: bitmask of requested extensions (bit X - 1 for extension X).
      states: bitmask of socket states to dump.

    Returns:
      The concatenation of the AF_INET and AF_INET6 dump results.
    """
    # DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it
    # results in ENOENT.
    if sock_id is None:
      sock_id = self._EmptyInetDiagSockId()

    sockets = []
    for family in [AF_INET, AF_INET6]:
      diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id))
      sockets += self.Dump(diag_req, bytecode)

    return sockets

  @staticmethod
  def GetRawAddress(family, addr):
    """Converts a 16-byte InetDiagSockId address field to a string.

    Only the first 4 (AF_INET) or 16 (AF_INET6) bytes are used. Raises
    KeyError for any other family.
    """
    addrlen = {AF_INET: 4, AF_INET6: 16}[family]
    return inet_ntop(family, addr[:addrlen])

  @staticmethod
  def GetSourceAddress(diag_msg):
    """Fetches the source address from an InetDiagMsg."""
    return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.src)

  @staticmethod
  def GetDestinationAddress(diag_msg):
    """Fetches the destination address from an InetDiagMsg."""
    return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.dst)

  @staticmethod
  def RawAddress(addr):
    """Converts an IP address string to binary format."""
    family = AF_INET6 if ":" in addr else AF_INET
    return inet_pton(family, addr)

  @staticmethod
  def PaddedAddress(addr):
    """Converts an IP address string to binary format for InetDiagSockId.

    IPv4 addresses are right-padded with zero bytes to 16 bytes.
    """
    padded = SockDiag.RawAddress(addr)
    if len(padded) < 16:
      padded += "\x00" * (16 - len(padded))
    return padded

  @staticmethod
  def DiagReqFromSocket(s):
    """Creates an InetDiagReqV2 that matches the specified socket.

    Unconnected sockets (ENOTCONN from getpeername) get a wildcard peer
    address and port 0. Other getpeername errors are re-raised.
    """
    family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    protocol = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_PROTOCOL)
    if net_test.LINUX_VERSION >= (3, 8):
      iface = s.getsockopt(SOL_SOCKET, net_test.SO_BINDTODEVICE,
                           net_test.IFNAMSIZ)
      # NOTE(review): GetInterfaceIndex is not defined or imported in this
      # file (and `from socket import *` does not provide it), so this line
      # raises NameError for a socket bound to a device. Confirm where it is
      # meant to come from, and whether it tolerates the trailing NUL the
      # getsockopt result may carry.
      iface = GetInterfaceIndex(iface) if iface else 0
    else:
      iface = 0
    src, sport = s.getsockname()[:2]
    try:
      dst, dport = s.getpeername()[:2]
    except error as e:
      if e.errno == errno.ENOTCONN:
        dport = 0
        dst = "::" if family == AF_INET6 else "0.0.0.0"
      else:
        # Bare raise keeps the original traceback (Python 2 `raise e` loses
        # it).
        raise
    src = SockDiag.PaddedAddress(src)
    dst = SockDiag.PaddedAddress(dst)
    sock_id = InetDiagSockId((sport, dport, src, dst, iface, "\x00" * 8))
    return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id))

  @staticmethod
  def GetSocketCookie(s):
    """Returns the socket's SO_COOKIE as an integer."""
    cookie = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)
    return struct.unpack("=Q", cookie)[0]

  def FindSockInfoFromFd(self, s):
    """Gets a diag_msg and attrs from the kernel for the specified socket.

    Raises:
      ValueError: if the dump is empty or contains no socket with s's inode.
    """
    req = self.DiagReqFromSocket(s)
    # The kernel doesn't use idiag_src and idiag_dst when dumping sockets, it
    # only uses them when targeting a specific socket with a cookie. Check
    # the inode number to ensure we don't mistakenly match another socket on
    # the same port but with a different IP address.
    inode = os.fstat(s.fileno()).st_ino
    results = self.Dump(req, "")
    if not results:
      raise ValueError("Dump of %s returned no sockets" % req)
    for diag_msg, attrs in results:
      if diag_msg.inode == inode:
        return diag_msg, attrs
    raise ValueError("Dump of %s did not contain inode %d" % (req, inode))

  def FindSockDiagFromFd(self, s):
    """Gets an InetDiagMsg from the kernel for the specified socket."""
    return self.FindSockInfoFromFd(s)[0]

  def GetSockInfo(self, req):
    """Gets a diag_msg and attrs from the kernel for the specified request."""
    self._SendNlRequest(SOCK_DIAG_BY_FAMILY, req.Pack(), netlink.NLM_F_REQUEST)
    return self._GetMsg(InetDiagMsg)

  @staticmethod
  def DiagReqFromDiagMsg(d, protocol):
    """Constructs a diag_req from a diag_msg the kernel has given us.

    The states mask selects only the state reported in the message.
    """
    return InetDiagReqV2((d.family, protocol, 0, 1 << d.state, d.id))

  def CloseSocket(self, req):
    """Sends a SOCK_DESTROY request for req, asking for an ACK."""
    self._SendNlRequest(SOCK_DESTROY, req.Pack(),
                        netlink.NLM_F_REQUEST | netlink.NLM_F_ACK)

  def CloseSocketFromFd(self, s):
    """Destroys the kernel socket backing the Python socket s."""
    diag_msg = self.FindSockDiagFromFd(s)
    protocol = s.getsockopt(SOL_SOCKET, net_test.SO_PROTOCOL)
    req = self.DiagReqFromDiagMsg(diag_msg, protocol)
    return self.CloseSocket(req)


if __name__ == "__main__":
  n = SockDiag()
  # NOTE(review): the class reads NL_DEBUG, not DEBUG; confirm whether the
  # netlink base class honors DEBUG.
  n.DEBUG = True
  bytecode = ""
  sock_id = n._EmptyInetDiagSockId()
  sock_id.dport = 443
  # Request the TOS and TCLASS extensions (bit X - 1 for extension X).
  ext = 1 << (INET_DIAG_TOS - 1) | 1 << (INET_DIAG_TCLASS - 1)
  states = 0xffffffff
  diag_msgs = n.DumpAllInetSockets(IPPROTO_TCP, bytecode,
                                   sock_id=sock_id, ext=ext, states=states)
  print(diag_msgs)