#!/usr/bin/python3
#
# Copyright 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Partial Python implementation of iproute functionality."""

# pylint: disable=g-bad-todo

import os
import socket
import struct
import sys

import cstruct
import util

### Base netlink constants. See include/uapi/linux/netlink.h.

# Netlink protocols.
NETLINK_ROUTE = 0
NETLINK_SOCK_DIAG = 4
NETLINK_XFRM = 6
NETLINK_GENERIC = 16

# Request flags, carried in the flags field of NLMsgHdr.
NLM_F_REQUEST = 0x001
NLM_F_ACK = 0x004
NLM_F_REPLACE = 0x100
NLM_F_EXCL = 0x200
NLM_F_CREATE = 0x400
NLM_F_DUMP = 0x300

# Control message types, carried in the type field of NLMsgHdr.
NLMSG_ERROR = 0x2
NLMSG_DONE = 0x3

# Data structure formats.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
# Message header: u32 length, u16 type, u16 flags, u32 seq, u32 pid.
NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
# Leading signed error code of an NLMSG_ERROR payload.
NLMsgErr = cstruct.Struct("NLMsgErr", "=i", "error")
# Attribute header: u16 length, u16 type.
NLAttr = cstruct.Struct("NLAttr", "=HH", "nla_len nla_type")

# Attribute data is zero-padded to a multiple of this many bytes.
NLA_ALIGNTO = 4

# List of attributes that can appear more than once in a given netlink message.
# These can appear more than once but don't seem to contain any data.
DUP_ATTRS_OK = ["INET_DIAG_NONE", "IFLA_PAD"]


def MakeConstantPrefixes(prefixes):
  """Returns prefixes sorted longest first.

  NetlinkSocket._GetConstantName relies on a module's CONSTANT_PREFIXES being
  in this order, so that the first matching prefix is also the longest one.

  Args:
    prefixes: An iterable of string prefixes, e.g. ["IFA_", "IFA_F_"].

  Returns:
    A new list containing the prefixes, longest first.
  """
  return sorted(prefixes, key=len, reverse=True)


class NetlinkSocket(object):
  """A basic netlink socket object.

  Provides attribute encoding/decoding helpers and request/response plumbing.
  _Decode and MaybeDebugCommand are no-op hooks intended for subclasses.
  """

  BUFSIZE = 65536
  DEBUG = False
  # List of netlink messages to print, e.g., [], ["NEIGH", "ROUTE"], or ["ALL"]
  NL_DEBUG = []

  def _Debug(self, s):
    """Prints s, but only when DEBUG is set."""
    if self.DEBUG:
      print(s)

  def _NlAttr(self, nla_type, data):
    """Encodes one attribute as an NLAttr header, the data, then zero padding.

    nla_len covers the header and the data but not the trailing padding.
    """
    assert isinstance(data, bytes)
    datalen = len(data)
    # Pad the data if it's not a multiple of NLA_ALIGNTO bytes long.
    padding = b"\x00" * util.GetPadLength(NLA_ALIGNTO, datalen)
    nla_len = datalen + len(NLAttr)
    return NLAttr((nla_len, nla_type)).Pack() + data + padding

  def _NlAttrIPAddress(self, nla_type, family, address):
    """Encodes a textual IP address as a packed-address attribute."""
    return self._NlAttr(nla_type, socket.inet_pton(family, address))

  def _NlAttrStr(self, nla_type, value):
    """Encodes a str as a NUL-terminated UTF-8 attribute."""
    value = value + "\x00"
    return self._NlAttr(nla_type, value.encode("UTF-8"))

  def _NlAttrU32(self, nla_type, value):
    """Encodes an integer as a native-byte-order u32 attribute."""
    return self._NlAttr(nla_type, struct.pack("=I", value))

  @staticmethod
  def _GetConstantName(module, value, prefix):
    """Returns the name of an upper-case constant in module equal to value.

    Args:
      module: The name of a module in sys.modules to search.
      value: The value to look up.
      prefix: Only names starting with this prefix are returned, e.g. "IFA_".

    Returns:
      The first matching name in dir() order, or value itself if none match.
    """

    def FirstMatching(name, prefixlist):
      for candidate in prefixlist:
        if name.startswith(candidate):
          return candidate
      return None

    thismodule = sys.modules[module]
    constant_prefixes = getattr(thismodule, "CONSTANT_PREFIXES", [])
    for name in dir(thismodule):
      if value != getattr(thismodule, name) or not name.isupper():
        continue
      # If the module explicitly specifies prefixes, only return this name if
      # the passed-in prefix is the longest prefix that matches the name.
      # This ensures, for example, that passing in a prefix of "IFA_" and a
      # value of 1 returns "IFA_ADDRESS" instead of "IFA_F_SECONDARY".
      # The longest matching prefix is always the first matching prefix because
      # CONSTANT_PREFIXES must be sorted longest first.
      if constant_prefixes and prefix != FirstMatching(name, constant_prefixes):
        continue
      if name.startswith(prefix):
        return name
    return value

  def _Decode(self, command, msg, nla_type, nla_data, nested):
    """No-op, nonspecific version of decode."""
    return nla_type, nla_data

  def _ReadNlAttr(self, data):
    """Splits one attribute off the front of data.

    Returns:
      A tuple (nla, nla_data, rest): the NLAttr header, the attribute's data
      without padding, and the bytes following the attribute's padding.
    """
    # Read the nlattr header.
    nla, data = cstruct.Read(data, NLAttr)

    # Read the data.
    datalen = nla.nla_len - len(nla)
    padded_len = util.GetPadLength(NLA_ALIGNTO, datalen) + datalen
    nla_data, data = data[:datalen], data[padded_len:]

    return nla, nla_data, data

  def _ParseAttributes(self, command, msg, data, nested):
    """Parses and decodes netlink attributes.

    Takes a block of NLAttr data structures, decodes them using Decode, and
    returns the result in a dict keyed by attribute number.

    Args:
      command: An integer, the rtnetlink command being carried out.
      msg: A Struct, the type of the data after the netlink header.
      data: A byte string containing a sequence of NLAttr data structures.
      nested: A list, outermost first, of each of the attributes the NLAttrs are
        nested inside. Empty for non-nested attributes.

    Returns:
      A dictionary mapping attribute types (integers) to decoded values.

    Raises:
      ValueError: There was a duplicate attribute type.
    """
    attributes = {}
    while data:
      nla, nla_data, data = self._ReadNlAttr(data)

      # If it's an attribute we know about, try to decode it.
      nla_name, nla_data = self._Decode(command, msg, nla.nla_type, nla_data,
                                        nested)

      if nla_name in attributes and nla_name not in DUP_ATTRS_OK:
        raise ValueError("Duplicate attribute %s" % nla_name)

      attributes[nla_name] = nla_data
      if not nested:
        self._Debug(" %s" % (str((nla_name, nla_data))))

    return attributes

  def _OpenNetlinkSocket(self, family, groups):
    """Opens a raw netlink socket of the given protocol, connected to the kernel.

    Args:
      family: The netlink protocol, e.g. NETLINK_ROUTE.
      groups: If truthy, the multicast groups value to bind to.

    Returns:
      The connected socket.
    """
    sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, family)
    try:
      if groups:
        sock.bind((0, groups))
      sock.connect((0, 0))  # The kernel.
    except BaseException:
      # Don't leak the file descriptor if bind or connect fails.
      sock.close()
      raise
    return sock

  def __init__(self, family, groups=None):
    # Assigned first so that close() and __del__ are safe even if opening the
    # socket below raises.
    self.sock = None
    # Global sequence number.
    self.seq = 0
    self.sock = self._OpenNetlinkSocket(family, groups)
    # AF_NETLINK socket addresses are (pid, groups) pairs, the same order used
    # by bind() above, so the kernel-assigned port ID is element 0.
    self.pid = self.sock.getsockname()[0]

  def close(self):
    """Closes the socket. Safe to call more than once."""
    if getattr(self, "sock", None) is not None:
      self.sock.close()
    self.sock = None

  def __del__(self):
    # getattr: construction may have failed before self.sock was assigned.
    if getattr(self, "sock", None):
      self.close()

  def MaybeDebugCommand(self, command, flags, data):
    # Default no-op implementation to be overridden by subclasses.
    pass

  def _Send(self, msg):
    """Sends msg and advances the sequence number for the next request."""
    self.seq += 1
    self.sock.send(msg)

  def _Recv(self):
    """Receives up to BUFSIZE bytes from the socket."""
    data = self.sock.recv(self.BUFSIZE)
    return data

  def _ExpectDone(self):
    """Receives one buffer and raises ValueError unless it is NLMSG_DONE."""
    response = self._Recv()
    hdr = NLMsgHdr(response)
    if hdr.type != NLMSG_DONE:
      raise ValueError("Expected DONE, got type %d" % hdr.type)

  def _ParseAck(self, response):
    """Checks that response is an NLMSG_ERROR message with a zero error code.

    Raises:
      IOError: The message carries a nonzero error code.
      ValueError: The message is not an NLMSG_ERROR message.
    """
    # Find the error code.
    hdr, data = cstruct.Read(response, NLMsgHdr)
    if hdr.type == NLMSG_ERROR:
      error = -NLMsgErr(data).error
      if error:
        raise IOError(error, os.strerror(error))
    else:
      raise ValueError("Expected ACK, got type %d" % hdr.type)

  def _ExpectAck(self):
    """Receives one buffer and checks it with _ParseAck."""
    response = self._Recv()
    self._ParseAck(response)

  def _SendNlRequest(self, command, data, flags):
    """Sends a netlink request; if flags include NLM_F_ACK, expects an ack."""
    length = len(NLMsgHdr) + len(data)
    nlmsg = NLMsgHdr((length, command, flags, self.seq, self.pid)).Pack()

    self.MaybeDebugCommand(command, flags, nlmsg + data)

    # Send the message.
    self._Send(nlmsg + data)

    if flags & NLM_F_ACK:
      self._ExpectAck()

  def _ParseNLMsg(self, data, msgtype):
    """Parses a Netlink message into a header and a dictionary of attributes.

    Args:
      data: A byte string starting with a netlink message.
      msgtype: A cstruct.Struct, the type of the data after the netlink header.

    Returns:
      A tuple ((nlmsg, attributes), rest), where rest is the data following
      this message. For NLMSG_ERROR and NLMSG_DONE messages, (nlmsg, attributes)
      is (None, None) and rest is the data following the netlink header.
    """
    nlmsghdr, data = cstruct.Read(data, NLMsgHdr)
    self._Debug(" %s" % nlmsghdr)

    if nlmsghdr.type == NLMSG_ERROR or nlmsghdr.type == NLMSG_DONE:
      # Only log when debugging: this is reached at the end of every dump.
      self._Debug("done")
      return (None, None), data

    nlmsg, data = cstruct.Read(data, msgtype)
    self._Debug(" %s" % nlmsg)

    # Parse the attributes in the nlmsg.
    attrlen = nlmsghdr.length - len(nlmsghdr) - len(nlmsg)
    attributes = self._ParseAttributes(nlmsghdr.type, nlmsg, data[:attrlen], [])
    # Consecutive messages start on 4-byte (NLMSG_ALIGNTO) boundaries, so skip
    # any padding after this message's declared length.
    padlen = util.GetPadLength(NLA_ALIGNTO, nlmsghdr.length)
    data = data[attrlen + padlen:]
    return (nlmsg, attributes), data

  def _GetMsg(self, msgtype):
    """Receives and parses one message, raising if it is a netlink error.

    Returns:
      (nlmsg, attributes) as returned by _ParseNLMsg; (None, None) if the
      message was a successful ack.
    """
    data = self._Recv()
    if NLMsgHdr(data).type == NLMSG_ERROR:
      self._ParseAck(data)
    return self._ParseNLMsg(data, msgtype)[0]

  def _ParseMsgBuffer(self, msgtype, data):
    """Parses messages from one receive buffer, stopping at a DONE/ERROR.

    Returns:
      A tuple (messages, terminated): the list of (nlmsg, attributes) tuples
      parsed, and whether an NLMSG_DONE or NLMSG_ERROR message was reached.
    """
    out = []
    while data:
      msg, data = self._ParseNLMsg(data, msgtype)
      # _ParseNLMsg signals DONE/ERROR with the tuple (None, None).
      if msg[0] is None:
        return out, True
      out.append(msg)
    return out, False

  def _GetMsgList(self, msgtype, data, expect_done):
    """Parses all messages in data.

    Args:
      msgtype: A cstruct.Struct, the type of the data after each netlink header.
      data: A byte string containing zero or more netlink messages.
      expect_done: If True and data did not itself end with an NLMSG_DONE or
        NLMSG_ERROR message, receive one more buffer and require NLMSG_DONE.

    Returns:
      A list of (nlmsg, attributes) tuples.
    """
    out, terminated = self._ParseMsgBuffer(msgtype, data)
    # If the terminator was already in data, no separate DONE will follow.
    if expect_done and not terminated:
      self._ExpectDone()
    return out

  def _Dump(self, command, msg, msgtype, attrs=b""):
    """Sends a dump request and returns a list of decoded messages.

    Args:
      command: An integer, the command to run (e.g., RTM_NEWADDR).
      msg: A struct, the request (e.g., a RTMsg). May be None.
      msgtype: A cstruct.Struct, the data type to parse the dump results as.
      attrs: A bytes object, the raw bytes of any request attributes to include.

    Returns:
      A list of (msg, attrs) tuples where msg is of type msgtype and attrs is
      a dict of attributes.

    Raises:
      IOError: The kernel answered the request with a netlink error.
    """
    # Create a netlink dump request containing the msg.
    flags = NLM_F_DUMP | NLM_F_REQUEST
    msg = b"" if msg is None else msg.Pack()
    length = len(NLMsgHdr) + len(msg) + len(attrs)
    nlmsghdr = NLMsgHdr((length, command, flags, self.seq, self.pid))

    # Send the request.
    request = nlmsghdr.Pack() + msg + attrs
    self.MaybeDebugCommand(command, flags, request)
    self._Send(request)

    # Keep reading netlink messages until we get a NLMSG_DONE. The DONE may
    # arrive on its own or at the end of the last buffer of results.
    out = []
    while True:
      data = self._Recv()
      if NLMsgHdr(data).type == NLMSG_ERROR:
        # Likely means that the kernel didn't like our dump request.
        # Parse the error and throw an exception.
        self._ParseAck(data)
      messages, terminated = self._ParseMsgBuffer(msgtype, data)
      out.extend(messages)
      if terminated:
        return out