#!/usr/bin/python
#
# Copyright 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Partial Python implementation of iproute functionality."""

# pylint: disable=g-bad-todo

import os
import socket
import struct
import sys

import cstruct


# Request constants.
NLM_F_REQUEST = 1
NLM_F_ACK = 4
NLM_F_REPLACE = 0x100
NLM_F_EXCL = 0x200
NLM_F_CREATE = 0x400
NLM_F_DUMP = 0x300

# Message types.
NLMSG_ERROR = 2
NLMSG_DONE = 3

# Data structure formats.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
NLMsgErr = cstruct.Struct("NLMsgErr", "=i", "error")
NLAttr = cstruct.Struct("NLAttr", "=HH", "nla_len nla_type")

# Alignment / padding.
NLA_ALIGNTO = 4


def PaddedLength(length):
  """Returns length rounded up to the next multiple of NLA_ALIGNTO."""
  # TODO: This padding is probably overly simplistic.
  # Floor division keeps the result an integer even when "/" means true
  # division (Python 3, or "from __future__ import division").
  return NLA_ALIGNTO * ((length // NLA_ALIGNTO) + (length % NLA_ALIGNTO != 0))


class NetlinkSocket(object):
  """A basic netlink socket object.

  __init__ reads self.FAMILY, which is not defined in this class; presumably
  subclasses set it to the netlink protocol number they speak.
  """

  BUFSIZE = 65536
  DEBUG = False
  # List of netlink messages to print, e.g., [], ["NEIGH", "ROUTE"], or ["ALL"]
  NL_DEBUG = []

  def _Debug(self, s):
    """Prints s if the DEBUG class attribute is set."""
    if self.DEBUG:
      # Single-argument call form: behaves the same as a Python 2 print
      # statement and as the Python 3 print function.
      print(s)

  def _NlAttr(self, nla_type, data):
    """Returns a packed netlink attribute: NLAttr header, data, zero padding."""
    datalen = len(data)
    # Pad the data if it's not a multiple of NLA_ALIGNTO bytes long.
    padding = b"\x00" * (PaddedLength(datalen) - datalen)
    # nla_len covers the header and the payload, but not the padding.
    nla_len = datalen + len(NLAttr)
    return NLAttr((nla_len, nla_type)).Pack() + data + padding

  def _NlAttrU32(self, nla_type, value):
    """Returns a packed attribute whose payload is a 32-bit unsigned value."""
    return self._NlAttr(nla_type, struct.pack("=I", value))

  def _GetConstantName(self, module, value, prefix):
    """Looks up the name of a module-level constant by its value.

    Args:
      module: A string, the name of an imported module (key in sys.modules).
      value: The value to look for.
      prefix: A string that the constant's name must start with.

    Returns:
      The first matching upper-case name in dir() order, or value itself if
      nothing matches. Names starting with prefix + "F_" and names starting
      with "INET_DIAG_BC" are never returned.
    """
    thismodule = sys.modules[module]
    for name in dir(thismodule):
      if name.startswith("INET_DIAG_BC"):
        continue
      if (name.startswith(prefix) and
          not name.startswith(prefix + "F_") and
          name.isupper() and getattr(thismodule, name) == value):
        return name
    return value

  def _Decode(self, command, msg, nla_type, nla_data):
    """No-op, nonspecific version of decode."""
    return nla_type, nla_data

  def _ParseAttributes(self, command, msg, data):
    """Parses and decodes netlink attributes.

    Takes a block of NLAttr data structures, decodes them using Decode, and
    returns the result in a dict keyed by attribute number.

    Args:
      command: An integer, the rtnetlink command being carried out.
      msg: A Struct, the type of the data after the netlink header.
      data: A byte string containing a sequence of NLAttr data structures.

    Returns:
      A dictionary mapping attribute types (integers) to decoded values.

    Raises:
      ValueError: There was a duplicate attribute type.
    """
    attributes = {}
    while data:
      # Read the nlattr header.
      nla, data = cstruct.Read(data, NLAttr)

      # Read the data. The payload is datalen bytes long, but the next
      # attribute only starts after the padding.
      datalen = nla.nla_len - len(nla)
      padded_len = PaddedLength(nla.nla_len) - len(nla)
      nla_data, data = data[:datalen], data[padded_len:]

      # If it's an attribute we know about, try to decode it.
      nla_name, nla_data = self._Decode(command, msg, nla.nla_type, nla_data)

      # We only support unique attributes for now, except for INET_DIAG_NONE,
      # which can appear more than once but doesn't seem to contain any data.
      if nla_name in attributes and nla_name != "INET_DIAG_NONE":
        raise ValueError("Duplicate attribute %s" % nla_name)

      attributes[nla_name] = nla_data
      self._Debug("      %s" % str((nla_name, nla_data)))

    return attributes

  def __init__(self):
    """Opens a raw netlink socket of family self.FAMILY and connects it."""
    # Global sequence number.
    self.seq = 0
    self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.FAMILY)
    self.sock.connect((0, 0))  # The kernel.
    self.pid = self.sock.getsockname()[1]

  def MaybeDebugCommand(self, command, flags, data):
    """Debug hook called with every outgoing request; does nothing here."""
    # Default no-op implementation to be overridden by subclasses.
    pass

  def _Send(self, msg):
    """Sends msg on the socket and advances the sequence number."""
    # self._Debug(msg.encode("hex"))
    # Callers pack self.seq into the header before calling this, so msg
    # carries the pre-increment value.
    self.seq += 1
    self.sock.send(msg)

  def _Recv(self):
    """Receives and returns up to BUFSIZE bytes from the socket."""
    data = self.sock.recv(self.BUFSIZE)
    # self._Debug(data.encode("hex"))
    return data

  def _ExpectDone(self):
    """Receives one buffer and checks that it starts with NLMSG_DONE.

    Raises:
      ValueError: The first message in the buffer was not NLMSG_DONE.
    """
    response = self._Recv()
    hdr = NLMsgHdr(response)
    if hdr.type != NLMSG_DONE:
      raise ValueError("Expected DONE, got type %d" % hdr.type)

  def _ParseAck(self, response):
    """Checks that response is an NLMSG_ERROR message with error code zero.

    Args:
      response: A byte string, the received data.

    Raises:
      IOError: The error code was nonzero.
      ValueError: The message was not of type NLMSG_ERROR.
    """
    # Find the error code.
    hdr, data = cstruct.Read(response, NLMsgHdr)
    if hdr.type == NLMSG_ERROR:
      error = NLMsgErr(data).error
      if error:
        # The code is negated before being passed to os.strerror.
        raise IOError(error, os.strerror(-error))
    else:
      raise ValueError("Expected ACK, got type %d" % hdr.type)

  def _ExpectAck(self):
    """Receives one buffer and parses it as an ACK (see _ParseAck)."""
    response = self._Recv()
    self._ParseAck(response)

  def _SendNlRequest(self, command, data, flags):
    """Sends a netlink request and expects an ack."""
    length = len(NLMsgHdr) + len(data)
    nlmsg = NLMsgHdr((length, command, flags, self.seq, self.pid)).Pack()

    self.MaybeDebugCommand(command, flags, nlmsg + data)

    # Send the message.
    self._Send(nlmsg + data)

    # An ack is only waited for if the caller requested one.
    if flags & NLM_F_ACK:
      self._ExpectAck()

  def _ParseNLMsg(self, data, msgtype):
    """Parses a Netlink message into a header and a dictionary of attributes.

    Args:
      data: A byte string starting with a netlink message header.
      msgtype: A cstruct.Struct, the type of the data after the header.

    Returns:
      A tuple ((nlmsg, attributes), remaining_data). For NLMSG_ERROR and
      NLMSG_DONE messages the first element is (None, None) and only the
      header has been consumed from data.
    """
    nlmsghdr, data = cstruct.Read(data, NLMsgHdr)
    self._Debug("  %s" % nlmsghdr)

    if nlmsghdr.type == NLMSG_ERROR or nlmsghdr.type == NLMSG_DONE:
      # Only report this when debugging; a parsing helper should not write
      # to stdout unconditionally.
      self._Debug("done")
      return (None, None), data

    nlmsg, data = cstruct.Read(data, msgtype)
    self._Debug("    %s" % nlmsg)

    # Parse the attributes in the nlmsg.
    attrlen = nlmsghdr.length - len(nlmsghdr) - len(nlmsg)
    attributes = self._ParseAttributes(nlmsghdr.type, nlmsg, data[:attrlen])
    data = data[attrlen:]
    return (nlmsg, attributes), data

  def _GetMsg(self, msgtype):
    """Receives one buffer and returns its first message as (msg, attrs).

    Raises:
      IOError: The buffer starts with an NLMSG_ERROR carrying a nonzero code.
    """
    data = self._Recv()
    if NLMsgHdr(data).type == NLMSG_ERROR:
      self._ParseAck(data)
    return self._ParseNLMsg(data, msgtype)[0]

  def _ParseMsgList(self, msgtype, data):
    """Parses all the messages in one receive buffer.

    Parsing stops at the first NLMSG_ERROR or NLMSG_DONE message; anything
    after it in the buffer is ignored.

    Args:
      msgtype: A cstruct.Struct, the data type to parse the messages as.
      data: A byte string containing zero or more netlink messages.

    Returns:
      A tuple (messages, done). messages is a list of (msg, attrs) tuples and
      done is True if parsing stopped at an NLMSG_DONE message.
    """
    out = []
    while data:
      # Peek at the type: _ParseNLMsg does not say which terminator it saw.
      hdr_type = NLMsgHdr(data).type
      msg, data = self._ParseNLMsg(data, msgtype)
      # _ParseNLMsg signals a terminator with the tuple (None, None), so the
      # tuple itself is never None.
      if msg[0] is None:
        return out, hdr_type == NLMSG_DONE
      out.append(msg)
    return out, False

  def _GetMsgList(self, msgtype, data, expect_done):
    """Parses a buffer of messages into a list of (msg, attrs) tuples.

    Args:
      msgtype: A cstruct.Struct, the data type to parse the messages as.
      data: A byte string containing zero or more netlink messages.
      expect_done: If true, read an NLMSG_DONE from the socket afterwards,
        unless one was already found in data.

    Returns:
      A list of (msg, attrs) tuples.

    Raises:
      ValueError: expect_done was set and the next buffer did not start with
        NLMSG_DONE.
    """
    out, done = self._ParseMsgList(msgtype, data)
    if expect_done and not done:
      self._ExpectDone()
    return out

  def _Dump(self, command, msg, msgtype, attrs):
    """Sends a dump request and returns a list of decoded messages.

    Args:
      command: An integer, the command to run (e.g., RTM_NEWADDR).
      msg: A struct, the request (e.g., a RTMsg). May be None.
      msgtype: A cstruct.Struct, the data type to parse the dump results as.
      attrs: A string, the raw bytes of any request attributes to include.

    Returns:
      A list of (msg, attrs) tuples where msg is of type msgtype and attrs is
      a dict of attributes.

    Raises:
      IOError: A receive buffer started with an NLMSG_ERROR carrying a
        nonzero code.
    """
    # Create a netlink dump request containing the msg.
    flags = NLM_F_DUMP | NLM_F_REQUEST
    msg = b"" if msg is None else msg.Pack()
    length = len(NLMsgHdr) + len(msg) + len(attrs)
    nlmsghdr = NLMsgHdr((length, command, flags, self.seq, self.pid))

    # Send the request.
    request = nlmsghdr.Pack() + msg + attrs
    self.MaybeDebugCommand(command, flags, request)
    self._Send(request)

    # Keep reading netlink messages until we get a NLMSG_DONE. The DONE may
    # be the first message of a buffer or may follow other messages in the
    # same buffer; stop reading in either case.
    out = []
    done = False
    while not done:
      data = self._Recv()
      response_type = NLMsgHdr(data).type
      if response_type == NLMSG_DONE:
        break
      elif response_type == NLMSG_ERROR:
        # Likely means that the kernel didn't like our dump request.
        # Parse the error and throw an exception.
        self._ParseAck(data)
      msgs, done = self._ParseMsgList(msgtype, data)
      out.extend(msgs)

    return out