• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#!/usr/bin/python
#
# Copyright 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Partial Python implementation of iproute functionality."""

# pylint: disable=g-bad-todo

21import os
22import socket
23import struct
24import sys
25
26import cstruct
27
28
# Request flags, placed in the flags field of NLMsgHdr when sending.
NLM_F_REQUEST = 0x001
NLM_F_ACK = 0x004  # Ask the kernel to reply with an NLMSG_ERROR ack.
NLM_F_REPLACE = 0x100
NLM_F_EXCL = 0x200
NLM_F_CREATE = 0x400
NLM_F_DUMP = 0x300  # Combined with NLM_F_REQUEST for dump requests.

# Message types that carry no request-specific payload struct.
NLMSG_ERROR = 0x2  # Error code; zero means a plain ACK.
NLMSG_DONE = 0x3  # End of a multi-message (dump) reply.

# Attribute payloads are zero-padded to a multiple of this many bytes.
NLA_ALIGNTO = 4

# Wire formats, all native byte order with standard sizes and no implicit
# alignment ("=" prefix):
#   NLMsgHdr: u32 length, u16 type, u16 flags, u32 seq, u32 pid.
#   NLMsgErr: s32 error (the payload of an NLMSG_ERROR message).
#   NLAttr:   u16 nla_len, u16 nla_type (the header of each attribute).
# These aren't constants, they're classes. So, pylint: disable=invalid-name
NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
NLMsgErr = cstruct.Struct("NLMsgErr", "=i", "error")
NLAttr = cstruct.Struct("NLAttr", "=HH", "nla_len nla_type")


def PaddedLength(length, alignment=None):
  """Rounds length up to the next multiple of the attribute alignment.

  Args:
    length: An integer, a length in bytes.
    alignment: An integer, the alignment in bytes. Defaults to NLA_ALIGNTO.

  Returns:
    An integer: the smallest multiple of alignment that is >= length.
  """
  if alignment is None:
    alignment = NLA_ALIGNTO
  # Floor division keeps the result an int on Python 3 as well as Python 2;
  # plain "/" would return a float on Python 3.
  return alignment * ((length + alignment - 1) // alignment)


class NetlinkSocket(object):
  """A basic netlink socket object.

  The constructor opens a raw netlink socket whose protocol is self.FAMILY.
  This class does not define FAMILY, so subclasses are expected to provide it.
  _Decode and MaybeDebugCommand are no-op hooks intended to be overridden.
  """

  BUFSIZE = 65536  # Maximum number of bytes read by one _Recv() call.
  DEBUG = False  # When True, _Debug() prints its argument.
  # List of netlink messages to print, e.g., [], ["NEIGH", "ROUTE"], or ["ALL"]
  NL_DEBUG = []

  def _Debug(self, s):
    """Prints s, but only when self.DEBUG is set."""
    if self.DEBUG:
      # Parenthesized so the same line works on Python 2 and Python 3.
      print(s)

  def _NlAttr(self, nla_type, data):
    """Builds one netlink attribute.

    Args:
      nla_type: An integer, the attribute type.
      data: A byte string, the attribute payload.

    Returns:
      A byte string: the NLAttr header, then the payload, then zero padding up
      to the next multiple of NLA_ALIGNTO bytes. The header's nla_len covers
      the header and payload but not the padding.
    """
    datalen = len(data)
    # Pad the data if it's not a multiple of NLA_ALIGNTO bytes long.
    padding = "\x00" * (PaddedLength(datalen) - datalen)
    nla_len = datalen + len(NLAttr)
    return NLAttr((nla_len, nla_type)).Pack() + data + padding

  def _NlAttrU32(self, nla_type, value):
    """Builds a netlink attribute whose payload is a native-endian u32."""
    return self._NlAttr(nla_type, struct.pack("=I", value))

  def _GetConstantName(self, module, value, prefix):
    """Maps a value back to the name of a constant in an imported module.

    Args:
      module: A string, the name of an already-imported module to search.
      value: The value to look up.
      prefix: A string. Only upper-case names starting with it are considered.
        Names starting with prefix + "F_" (flags) and names starting with
        "INET_DIAG_BC" are skipped.

    Returns:
      The first matching name in dir() order, or value itself if no name
      matches.
    """
    thismodule = sys.modules[module]
    for name in dir(thismodule):
      if name.startswith("INET_DIAG_BC"):
        continue
      if (name.startswith(prefix) and
          not name.startswith(prefix + "F_") and
          name.isupper() and getattr(thismodule, name) == value):
        return name
    return value

  def _Decode(self, command, msg, nla_type, nla_data):
    """No-op, nonspecific version of decode.

    Returns:
      The tuple (nla_type, nla_data), unchanged.
    """
    return nla_type, nla_data

  def _ParseAttributes(self, command, msg, data):
    """Parses and decodes netlink attributes.

    Takes a block of NLAttr data structures, decodes each one using _Decode,
    and returns the results in a dict keyed by the name _Decode returns.

    Args:
      command: An integer, the netlink message type being parsed (for
        example, an rtnetlink command).
      msg: A Struct, the type of the data after the netlink header.
      data: A byte string containing a sequence of NLAttr data structures.

    Returns:
      A dictionary mapping attribute names or types, as returned by _Decode, to
      decoded values.

    Raises:
      ValueError: There was a duplicate attribute type, or an attribute whose
        nla_len is smaller than the attribute header.
    """
    attributes = {}
    while data:
      # Read the nlattr header.
      nla, data = cstruct.Read(data, NLAttr)
      header_len = len(nla)

      # nla_len includes the header, so a smaller value is malformed. Without
      # this check, the negative slice offsets below would silently misparse
      # the rest of the buffer.
      if nla.nla_len < header_len:
        raise ValueError("Invalid attribute length %d" % nla.nla_len)

      # Read the payload, then skip past it and its alignment padding.
      datalen = nla.nla_len - header_len
      padded_len = PaddedLength(nla.nla_len) - header_len
      nla_data, data = data[:datalen], data[padded_len:]

      # If it's an attribute we know about, try to decode it.
      nla_name, nla_data = self._Decode(command, msg, nla.nla_type, nla_data)

      # We only support unique attributes for now, except for INET_DIAG_NONE,
      # which can appear more than once but doesn't seem to contain any data.
      if nla_name in attributes and nla_name != "INET_DIAG_NONE":
        raise ValueError("Duplicate attribute %s" % nla_name)

      attributes[nla_name] = nla_data
      self._Debug("      %s" % str((nla_name, nla_data)))

    return attributes

  def __init__(self):
    # Global sequence number.
    self.seq = 0
    self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.FAMILY)
    try:
      self.sock.connect((0, 0))  # The kernel.
      self.pid = self.sock.getsockname()[1]
    except socket.error:
      # Do not leak the file descriptor if the socket cannot be set up.
      self.sock.close()
      raise

  def MaybeDebugCommand(self, command, flags, data):
    # Default no-op implementation to be overridden by subclasses.
    pass

  def _Send(self, msg):
    """Sends msg and advances the sequence number used for the next request."""
    self.seq += 1
    self.sock.send(msg)

  def _Recv(self):
    """Returns the bytes from one recv() of at most BUFSIZE bytes."""
    return self.sock.recv(self.BUFSIZE)

  def _ExpectDone(self):
    """Receives once and raises ValueError unless it starts with NLMSG_DONE."""
    response = self._Recv()
    hdr = NLMsgHdr(response)
    if hdr.type != NLMSG_DONE:
      raise ValueError("Expected DONE, got type %d" % hdr.type)

  def _ParseAck(self, response):
    """Checks that response is a successful netlink ACK.

    Args:
      response: A byte string starting with an NLMsgHdr.

    Raises:
      IOError: response is an NLMSG_ERROR with a nonzero error code. The
        IOError's errno is the kernel's error value as received (negative),
        while its strerror describes the corresponding positive errno.
      ValueError: response is not an NLMSG_ERROR message.
    """
    # Find the error code.
    hdr, data = cstruct.Read(response, NLMsgHdr)
    if hdr.type == NLMSG_ERROR:
      error = NLMsgErr(data).error
      if error:
        # NOTE(review): errno is kept negative for compatibility with existing
        # callers; IOError conventionally carries a positive errno.
        raise IOError(error, os.strerror(-error))
    else:
      raise ValueError("Expected ACK, got type %d" % hdr.type)

  def _ExpectAck(self):
    """Receives once and checks the result with _ParseAck."""
    response = self._Recv()
    self._ParseAck(response)

  def _SendNlRequest(self, command, data, flags):
    """Sends a netlink request, and waits for an ack if NLM_F_ACK is set.

    Args:
      command: An integer, the message type placed in the NLMsgHdr.
      data: A byte string, the request body that follows the NLMsgHdr.
      flags: An integer, the NLM_F_* flags for the NLMsgHdr.
    """
    length = len(NLMsgHdr) + len(data)
    nlmsg = NLMsgHdr((length, command, flags, self.seq, self.pid)).Pack()

    self.MaybeDebugCommand(command, flags, nlmsg + data)

    # Send the message.
    self._Send(nlmsg + data)

    if flags & NLM_F_ACK:
      self._ExpectAck()

  def _ParseNLMsg(self, data, msgtype):
    """Parses a Netlink message into a header and a dictionary of attributes.

    Args:
      data: A byte string starting with an NLMsgHdr.
      msgtype: A cstruct.Struct, the fixed header that follows the NLMsgHdr.

    Returns:
      A tuple ((nlmsg, attributes), rest), where rest is the data following the
      parsed message. For NLMSG_ERROR and NLMSG_DONE messages nothing is
      decoded: (None, None) is returned in place of (nlmsg, attributes), and
      rest starts right after the NLMsgHdr.
    """
    nlmsghdr, data = cstruct.Read(data, NLMsgHdr)
    self._Debug("  %s" % nlmsghdr)

    if nlmsghdr.type == NLMSG_ERROR or nlmsghdr.type == NLMSG_DONE:
      # Previously an unconditional print; this is debug output only.
      self._Debug("done")
      return (None, None), data

    nlmsg, data = cstruct.Read(data, msgtype)
    self._Debug("    %s" % nlmsg)

    # Parse the attributes in the nlmsg.
    attrlen = nlmsghdr.length - len(nlmsghdr) - len(nlmsg)
    attributes = self._ParseAttributes(nlmsghdr.type, nlmsg, data[:attrlen])
    data = data[attrlen:]
    return (nlmsg, attributes), data

  def _GetMsg(self, msgtype):
    """Receives once and parses the first message as msgtype.

    Raises:
      IOError: The message is an NLMSG_ERROR with a nonzero error code.
    """
    data = self._Recv()
    if NLMsgHdr(data).type == NLMSG_ERROR:
      self._ParseAck(data)
    return self._ParseNLMsg(data, msgtype)[0]

  def _ParseMsgBuffer(self, msgtype, data):
    """Parses consecutive messages until the data ends or a terminator is hit.

    Args:
      msgtype: A cstruct.Struct, the type to parse each message as.
      data: A byte string containing zero or more netlink messages.

    Returns:
      A tuple (messages, terminator). messages is a list of (msg, attributes)
      tuples. terminator is NLMSG_DONE or NLMSG_ERROR if parsing stopped at
      such a message, or None if the data was exhausted first.

    Raises:
      IOError: data contains an NLMSG_ERROR with a nonzero error code.
    """
    out = []
    while data:
      hdr_type = NLMsgHdr(data).type
      if hdr_type == NLMSG_ERROR:
        # Raises for a real error; a zero error code is a plain ACK.
        self._ParseAck(data)
      msg, data = self._ParseNLMsg(data, msgtype)
      # _ParseNLMsg signals a terminator with (None, None), never with None.
      if msg[0] is None:
        return out, hdr_type
      out.append(msg)
    return out, None

  def _GetMsgList(self, msgtype, data, expect_done):
    """Parses all messages in data, optionally waiting for NLMSG_DONE.

    Args:
      msgtype: A cstruct.Struct, the type to parse each message as.
      data: A byte string containing zero or more netlink messages.
      expect_done: If True, also require an NLMSG_DONE. A separate recv() is
        done only if data itself did not end with NLMSG_DONE.

    Returns:
      A list of (msg, attributes) tuples.

    Raises:
      IOError: data contains an NLMSG_ERROR with a nonzero error code.
      ValueError: expect_done is set and the next message is not NLMSG_DONE.
    """
    out, terminator = self._ParseMsgBuffer(msgtype, data)
    if expect_done and terminator != NLMSG_DONE:
      self._ExpectDone()
    return out

  def _Dump(self, command, msg, msgtype, attrs):
    """Sends a dump request and returns a list of decoded messages.

    Args:
      command: An integer, the command to run (e.g., RTM_NEWADDR).
      msg: A struct, the request (e.g., a RTMsg). May be None.
      msgtype: A cstruct.Struct, the data type to parse the dump results as.
      attrs: A string, the raw bytes of any request attributes to include.

    Returns:
      A list of (msg, attrs) tuples where msg is of type msgtype and attrs is
      a dict of attributes.

    Raises:
      IOError: The kernel replied with an NLMSG_ERROR carrying an error code.
    """
    # Create a netlink dump request containing the msg.
    flags = NLM_F_DUMP | NLM_F_REQUEST
    msg = "" if msg is None else msg.Pack()
    length = len(NLMsgHdr) + len(msg) + len(attrs)
    nlmsghdr = NLMsgHdr((length, command, flags, self.seq, self.pid))

    # Send the request.
    request = nlmsghdr.Pack() + msg + attrs
    self.MaybeDebugCommand(command, flags, request)
    self._Send(request)

    # Keep reading until NLMSG_DONE. It can arrive on its own or at the end of
    # a buffer that also holds dump results, so every message is inspected
    # rather than only the first one of each recv().
    out = []
    while True:
      messages, terminator = self._ParseMsgBuffer(msgtype, self._Recv())
      out.extend(messages)
      if terminator == NLMSG_DONE:
        return out
