1#!/usr/bin/python 2# 3# Copyright 2014 The Android Open Source Project 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16 17"""Partial Python implementation of iproute functionality.""" 18 19# pylint: disable=g-bad-todo 20 21import os 22import socket 23import struct 24import sys 25 26import cstruct 27import util 28 29### Base netlink constants. See include/uapi/linux/netlink.h. 30NETLINK_ROUTE = 0 31NETLINK_SOCK_DIAG = 4 32NETLINK_XFRM = 6 33NETLINK_GENERIC = 16 34 35# Request constants. 36NLM_F_REQUEST = 1 37NLM_F_ACK = 4 38NLM_F_REPLACE = 0x100 39NLM_F_EXCL = 0x200 40NLM_F_CREATE = 0x400 41NLM_F_DUMP = 0x300 42 43# Message types. 44NLMSG_ERROR = 2 45NLMSG_DONE = 3 46 47# Data structure formats. 48# These aren't constants, they're classes. So, pylint: disable=invalid-name 49NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid") 50NLMsgErr = cstruct.Struct("NLMsgErr", "=i", "error") 51NLAttr = cstruct.Struct("NLAttr", "=HH", "nla_len nla_type") 52 53# Alignment / padding. 54NLA_ALIGNTO = 4 55 56# List of attributes that can appear more than once in a given netlink message. 57# These can appear more than once but don't seem to contain any data. 
DUP_ATTRS_OK = ["INET_DIAG_NONE", "IFLA_PAD"]

# Successive netlink messages in one buffer start on multiples of this many
# bytes (NLMSG_ALIGNTO in include/uapi/linux/netlink.h).
NLMSG_ALIGNTO = 4


class NetlinkSocket(object):
  """A basic netlink socket object.

  Builds netlink attributes, sends requests to the kernel and parses the
  replies (acks, single messages and multi-part dumps). Protocol-specific
  subclasses are expected to override _Decode and MaybeDebugCommand.
  """

  BUFSIZE = 65536
  DEBUG = False
  # List of netlink messages to print, e.g., [], ["NEIGH", "ROUTE"], or ["ALL"]
  NL_DEBUG = []

  def _Debug(self, s):
    """Prints s, but only when DEBUG is set."""
    if self.DEBUG:
      print(s)

  def _NlAttr(self, nla_type, data):
    """Packs one netlink attribute.

    Args:
      nla_type: An integer, the attribute type.
      data: A byte string, the attribute payload.

    Returns:
      A byte string: the NLAttr header, the payload, then zero padding up to
      a multiple of NLA_ALIGNTO bytes.
    """
    datalen = len(data)
    # Padding must be bytes: payloads are bytes, and bytes + str raises
    # TypeError on Python 3.
    padding = b"\x00" * util.GetPadLength(NLA_ALIGNTO, datalen)
    # nla_len counts the header and payload but not the trailing padding.
    nla_len = datalen + len(NLAttr)
    return NLAttr((nla_len, nla_type)).Pack() + data + padding

  def _NlAttrIPAddress(self, nla_type, family, address):
    """Packs an attribute holding address in binary (inet_pton) form."""
    return self._NlAttr(nla_type, socket.inet_pton(family, address))

  def _NlAttrStr(self, nla_type, value):
    """Packs a NUL-terminated, UTF-8 encoded string attribute."""
    value = value + "\x00"
    return self._NlAttr(nla_type, value.encode("UTF-8"))

  def _NlAttrU32(self, nla_type, value):
    """Packs a native-endian 32-bit unsigned integer attribute."""
    return self._NlAttr(nla_type, struct.pack("=I", value))

  def _GetConstantName(self, module, value, prefix):
    """Looks up the name of a constant by value.

    Args:
      module: A string, the name of an already-imported module to search.
      value: The value to look for.
      prefix: A string; only upper-case names starting with it are considered,
        excluding names starting with prefix + "F_".

    Returns:
      The first matching name in dir() order, or value itself if none match.
    """
    thismodule = sys.modules[module]
    for name in dir(thismodule):
      # NOTE(review): INET_DIAG_BC_* names are always skipped, presumably
      # because their values collide with other INET_DIAG_ constants.
      if name.startswith("INET_DIAG_BC"):
        continue
      if (name.startswith(prefix) and
          not name.startswith(prefix + "F_") and
          name.isupper() and getattr(thismodule, name) == value):
        return name
    return value

  def _Decode(self, command, msg, nla_type, nla_data):
    """No-op, nonspecific version of decode.

    Subclasses override this to map nla_type to a name and nla_data to a
    decoded value. This version returns both unchanged.
    """
    return nla_type, nla_data

  def _ReadNlAttr(self, data):
    """Reads one netlink attribute from the start of data.

    Args:
      data: A byte string starting with an NLAttr header.

    Returns:
      A tuple (nla, nla_data, rest): the NLAttr header, its payload, and the
      bytes following the payload and its alignment padding.

    Raises:
      ValueError: nla_len is smaller than the attribute header itself.
    """
    # Read the nlattr header.
    nla, data = cstruct.Read(data, NLAttr)

    # Read the data. A length shorter than the header is malformed and would
    # otherwise be silently misparsed via negative slice indices.
    datalen = nla.nla_len - len(nla)
    if datalen < 0:
      raise ValueError("Invalid attribute length %d" % nla.nla_len)
    padded_len = util.GetPadLength(NLA_ALIGNTO, datalen) + datalen
    nla_data, data = data[:datalen], data[padded_len:]

    return nla, nla_data, data

  def _ParseAttributes(self, command, msg, data, nested=0):
    """Parses and decodes netlink attributes.

    Takes a block of NLAttr data structures, decodes them using _Decode, and
    returns the result in a dict keyed by the name _Decode returns for each
    attribute (the integer attribute type, for the base _Decode).

    Args:
      command: An integer, the rtnetlink command being carried out.
      msg: A Struct, the type of the data after the netlink header.
      data: A byte string containing a sequence of NLAttr data structures.
      nested: An integer, how deep we're currently nested.

    Returns:
      A dictionary mapping attribute names to decoded values. For names in
      DUP_ATTRS_OK that occur more than once, the last value wins.

    Raises:
      ValueError: There was a duplicate attribute not listed in DUP_ATTRS_OK.
    """
    attributes = {}
    while data:
      nla, nla_data, data = self._ReadNlAttr(data)

      # If it's an attribute we know about, try to decode it.
      nla_name, nla_data = self._Decode(command, msg, nla.nla_type, nla_data)

      if nla_name in attributes and nla_name not in DUP_ATTRS_OK:
        raise ValueError("Duplicate attribute %s" % nla_name)

      attributes[nla_name] = nla_data
      if not nested:
        self._Debug(" %s" % (str((nla_name, nla_data))))

    return attributes

  def _OpenNetlinkSocket(self, family, groups):
    """Opens a raw netlink socket of the given family connected to the kernel.

    If groups is truthy, the socket is first bound with those multicast group
    bits. The socket is closed if binding or connecting fails.
    """
    sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, family)
    try:
      if groups:
        sock.bind((0, groups))
      sock.connect((0, 0))  # The kernel.
    except Exception:
      # Don't leak the file descriptor; let the caller see the original error.
      sock.close()
      raise
    return sock

  def __init__(self, family, groups=None):
    # Global sequence number.
    self.seq = 0
    self.sock = self._OpenNetlinkSocket(family, groups)
    # The port ID the kernel assigned to this socket, used in request headers.
    self.pid = self.sock.getsockname()[1]

  def MaybeDebugCommand(self, command, flags, data):
    """Hook called with each outgoing request; no-op unless overridden."""
    pass

  def _Send(self, msg):
    """Sends msg, then advances the sequence number for the next request."""
    self.seq += 1
    self.sock.send(msg)

  def _Recv(self):
    """Receives one datagram of up to BUFSIZE bytes."""
    return self.sock.recv(self.BUFSIZE)

  def _ExpectDone(self):
    """Receives one datagram and raises ValueError unless it is NLMSG_DONE."""
    response = self._Recv()
    hdr = NLMsgHdr(response)
    if hdr.type != NLMSG_DONE:
      raise ValueError("Expected DONE, got type %d" % hdr.type)

  def _ParseAck(self, response):
    """Checks that response starts with a successful NLMSG_ERROR (an ack).

    Raises:
      IOError: The kernel reported an error; errno is the negated error code.
      ValueError: The message is not an NLMSG_ERROR at all.
    """
    # Find the error code.
    hdr, data = cstruct.Read(response, NLMsgHdr)
    if hdr.type == NLMSG_ERROR:
      error = NLMsgErr(data).error
      if error:
        raise IOError(-error, os.strerror(-error))
    else:
      raise ValueError("Expected ACK, got type %d" % hdr.type)

  def _ExpectAck(self):
    """Receives one datagram and checks it with _ParseAck."""
    response = self._Recv()
    self._ParseAck(response)

  def _SendNlRequest(self, command, data, flags):
    """Sends a netlink request; waits for the ack if flags has NLM_F_ACK."""
    length = len(NLMsgHdr) + len(data)
    nlmsg = NLMsgHdr((length, command, flags, self.seq, self.pid)).Pack()

    self.MaybeDebugCommand(command, flags, nlmsg + data)

    # Send the message.
    self._Send(nlmsg + data)

    if flags & NLM_F_ACK:
      self._ExpectAck()

  def _ParseNLMsg(self, data, msgtype):
    """Parses one netlink message into its fixed header and attributes.

    Args:
      data: A byte string starting with a netlink message.
      msgtype: A cstruct.Struct, the type of the data after the netlink header.

    Returns:
      A tuple ((nlmsg, attributes), rest): nlmsg is a msgtype instance,
      attributes a dict from _ParseAttributes, and rest the data following
      this message. For NLMSG_ERROR and NLMSG_DONE messages the first element
      is (None, None) and rest is the data after the netlink header.
    """
    nlmsghdr, data = cstruct.Read(data, NLMsgHdr)
    self._Debug(" %s" % nlmsghdr)

    if nlmsghdr.type == NLMSG_ERROR or nlmsghdr.type == NLMSG_DONE:
      # Debug output only; this used to be an unconditional print.
      self._Debug("done")
      return (None, None), data

    nlmsg, data = cstruct.Read(data, msgtype)
    self._Debug(" %s" % nlmsg)

    # Parse the attributes in the nlmsg.
    attrlen = nlmsghdr.length - len(nlmsghdr) - len(nlmsg)
    attributes = self._ParseAttributes(nlmsghdr.type, nlmsg, data[:attrlen])
    # The next message starts at an NLMSG_ALIGNTO boundary.
    padlen = util.GetPadLength(NLMSG_ALIGNTO, nlmsghdr.length)
    data = data[attrlen + padlen:]
    return (nlmsg, attributes), data

  def _ParseMsgsUntilEnd(self, msgtype, data):
    """Parses consecutive messages in data up to the first ERROR or DONE.

    Returns:
      A tuple (msgs, tail). msgs is a list of (msg, attrs) tuples. tail is the
      data starting at the first NLMSG_ERROR or NLMSG_DONE header, or None if
      data contained no such message.
    """
    msgs = []
    while data:
      if NLMsgHdr(data).type in (NLMSG_ERROR, NLMSG_DONE):
        return msgs, data
      msg, data = self._ParseNLMsg(data, msgtype)
      msgs.append(msg)
    return msgs, None

  def _GetMsg(self, msgtype):
    """Receives and parses a single message, raising if it is an error."""
    data = self._Recv()
    if NLMsgHdr(data).type == NLMSG_ERROR:
      self._ParseAck(data)
    return self._ParseNLMsg(data, msgtype)[0]

  def _GetMsgList(self, msgtype, data, expect_done):
    """Parses the messages in data, stopping at the first ERROR or DONE.

    Args:
      msgtype: A cstruct.Struct, the type to parse each message as.
      data: A byte string containing zero or more netlink messages.
      expect_done: If True, the sequence must be terminated by NLMSG_DONE.
        This is either found in data itself or received as the next datagram.

    Returns:
      A list of (msg, attrs) tuples.

    Raises:
      ValueError: expect_done is set but the terminator is not NLMSG_DONE.
    """
    out, tail = self._ParseMsgsUntilEnd(msgtype, data)
    if expect_done:
      if tail is None:
        # The terminator was not in this buffer; it must be the next one.
        self._ExpectDone()
      else:
        tail_type = NLMsgHdr(tail).type
        if tail_type != NLMSG_DONE:
          raise ValueError("Expected DONE, got type %d" % tail_type)
    return out

  def _Dump(self, command, msg, msgtype, attrs):
    """Sends a dump request and returns a list of decoded messages.

    Args:
      command: An integer, the command to run (e.g., RTM_NEWADDR).
      msg: A struct, the request (e.g., a RTMsg). May be None.
      msgtype: A cstruct.Struct, the data type to parse the dump results as.
      attrs: A byte string, the raw bytes of any request attributes to include.

    Returns:
      A list of (msg, attrs) tuples where msg is of type msgtype and attrs is
      a dict of attributes.

    Raises:
      IOError: The kernel answered the dump with an error.
    """
    # Create a netlink dump request containing the msg.
    flags = NLM_F_DUMP | NLM_F_REQUEST
    msg = b"" if msg is None else msg.Pack()
    length = len(NLMsgHdr) + len(msg) + len(attrs)
    nlmsghdr = NLMsgHdr((length, command, flags, self.seq, self.pid))

    # Send the request.
    request = nlmsghdr.Pack() + msg + attrs
    self.MaybeDebugCommand(command, flags, request)
    self._Send(request)

    # Keep reading netlink messages until we get a NLMSG_DONE, which may
    # arrive on its own or share a datagram with the last data messages.
    out = []
    while True:
      msgs, tail = self._ParseMsgsUntilEnd(msgtype, self._Recv())
      out.extend(msgs)
      if tail is not None:
        if NLMsgHdr(tail).type == NLMSG_ERROR:
          # Likely means that the kernel didn't like our dump request.
          # Parse the error and throw an exception.
          self._ParseAck(tail)
        break

    return out