• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3
2#
3# Copyright 2014 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Partial Python implementation of iproute functionality."""
18
19# pylint: disable=g-bad-todo
20
21import os
22import socket
23import struct
24import sys
25
26import cstruct
27import util
28
### Base netlink constants. See include/uapi/linux/netlink.h.
# Netlink protocol numbers.
# NOTE(review): presumably these are the values callers pass as `family` when
# constructing a NetlinkSocket -- confirm against callers.
NETLINK_ROUTE = 0
NETLINK_SOCK_DIAG = 4
NETLINK_XFRM = 6
NETLINK_GENERIC = 16

# Request constants. These are bit flags for the `flags` field of NLMsgHdr.
NLM_F_REQUEST = 1
NLM_F_ACK = 4
NLM_F_REPLACE = 0x100
NLM_F_EXCL = 0x200
NLM_F_CREATE = 0x400
# Numerically 0x100 | 0x200, i.e. the same bits as NLM_F_REPLACE | NLM_F_EXCL,
# so a flags value cannot be decoded without knowing the kind of request.
NLM_F_DUMP = 0x300

# Message types.
NLMSG_ERROR = 2
NLMSG_DONE = 3

# Data structure formats.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
# NLMsgHdr: header that starts every netlink message.
NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
# NLMsgErr: the error field read from the payload of an NLMSG_ERROR message.
NLMsgErr = cstruct.Struct("NLMsgErr", "=i", "error")
# NLAttr: header that precedes the payload of every netlink attribute.
NLAttr = cstruct.Struct("NLAttr", "=HH", "nla_len nla_type")

# Alignment / padding. Attribute payloads are padded to a multiple of this.
NLA_ALIGNTO = 4

# Names of attributes that may appear more than once in a given netlink
# message; _ParseAttributes does not treat duplicates of these as an error.
# They don't seem to contain any data.
DUP_ATTRS_OK = ["INET_DIAG_NONE", "IFLA_PAD"]
59
60
def MakeConstantPrefixes(prefixes):
  """Returns the given prefixes as a new list ordered longest first.

  Prefixes of equal length keep their relative input order. The input iterable
  is not modified.

  Args:
    prefixes: An iterable of strings.

  Returns:
    A list of the same strings, sorted by decreasing length.
  """
  ordered = list(prefixes)
  ordered.sort(key=len, reverse=True)
  return ordered
63
64
65class NetlinkSocket(object):
66  """A basic netlink socket object."""
67
68  BUFSIZE = 65536
69  DEBUG = False
70  # List of netlink messages to print, e.g., [], ["NEIGH", "ROUTE"], or ["ALL"]
71  NL_DEBUG = []
72
73  def _Debug(self, s):
74    if self.DEBUG:
75      print(s)
76
77  def _NlAttr(self, nla_type, data):
78    assert isinstance(data, bytes)
79    datalen = len(data)
80    # Pad the data if it's not a multiple of NLA_ALIGNTO bytes long.
81    padding = b"\x00" * util.GetPadLength(NLA_ALIGNTO, datalen)
82    nla_len = datalen + len(NLAttr)
83    return NLAttr((nla_len, nla_type)).Pack() + data + padding
84
85  def _NlAttrIPAddress(self, nla_type, family, address):
86    return self._NlAttr(nla_type, socket.inet_pton(family, address))
87
88  def _NlAttrStr(self, nla_type, value):
89    value = value + "\x00"
90    return self._NlAttr(nla_type, value.encode("UTF-8"))
91
92  def _NlAttrU32(self, nla_type, value):
93    return self._NlAttr(nla_type, struct.pack("=I", value))
94
95  @staticmethod
96  def _GetConstantName(module, value, prefix):
97
98    def FirstMatching(name, prefixlist):
99      for prefix in prefixlist:
100        if name.startswith(prefix):
101         return prefix
102      return None
103
104    thismodule = sys.modules[module]
105    constant_prefixes = getattr(thismodule, "CONSTANT_PREFIXES", [])
106    for name in dir(thismodule):
107      if value != getattr(thismodule, name) or not name.isupper():
108        continue
109      # If the module explicitly specifies prefixes, only return this name if
110      # the passed-in prefix is the longest prefix that matches the name.
111      # This ensures, for example, that passing in a prefix of "IFA_" and a
112      # value of 1 returns "IFA_ADDRESS" instead of "IFA_F_SECONDARY".
113      # The longest matching prefix is always the first matching prefix because
114      # CONSTANT_PREFIXES must be sorted longest first.
115      if constant_prefixes and prefix != FirstMatching(name, constant_prefixes):
116        continue
117      if name.startswith(prefix):
118        return name
119    return value
120
121  def _Decode(self, command, msg, nla_type, nla_data, nested):
122    """No-op, nonspecific version of decode."""
123    return nla_type, nla_data
124
125  def _ReadNlAttr(self, data):
126    # Read the nlattr header.
127    nla, data = cstruct.Read(data, NLAttr)
128
129    # Read the data.
130    datalen = nla.nla_len - len(nla)
131    padded_len = util.GetPadLength(NLA_ALIGNTO, datalen) + datalen
132    nla_data, data = data[:datalen], data[padded_len:]
133
134    return nla, nla_data, data
135
136  def _ParseAttributes(self, command, msg, data, nested):
137    """Parses and decodes netlink attributes.
138
139    Takes a block of NLAttr data structures, decodes them using Decode, and
140    returns the result in a dict keyed by attribute number.
141
142    Args:
143      command: An integer, the rtnetlink command being carried out.
144      msg: A Struct, the type of the data after the netlink header.
145      data: A byte string containing a sequence of NLAttr data structures.
146      nested: A list, outermost first, of each of the attributes the NLAttrs are
147              nested inside. Empty for non-nested attributes.
148
149    Returns:
150      A dictionary mapping attribute types (integers) to decoded values.
151
152    Raises:
153      ValueError: There was a duplicate attribute type.
154    """
155    attributes = {}
156    while data:
157      nla, nla_data, data = self._ReadNlAttr(data)
158
159      # If it's an attribute we know about, try to decode it.
160      nla_name, nla_data = self._Decode(command, msg, nla.nla_type, nla_data, nested)
161
162      if nla_name in attributes and nla_name not in DUP_ATTRS_OK:
163        raise ValueError("Duplicate attribute %s" % nla_name)
164
165      attributes[nla_name] = nla_data
166      if not nested:
167        self._Debug("      %s" % (str((nla_name, nla_data))))
168
169    return attributes
170
171  def _OpenNetlinkSocket(self, family, groups):
172    sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, family)
173    if groups:
174      sock.bind((0,  groups))
175    sock.connect((0, 0))  # The kernel.
176    return sock
177
178  def __init__(self, family, groups=None):
179    # Global sequence number.
180    self.seq = 0
181    self.sock = self._OpenNetlinkSocket(family, groups)
182    self.pid = self.sock.getsockname()[1]
183
184  def close(self):
185    self.sock.close()
186    self.sock = None
187
188  def __del__(self):
189    if self.sock:
190      self.close()
191
192  def MaybeDebugCommand(self, command, flags, data):
193    # Default no-op implementation to be overridden by subclasses.
194    pass
195
196  def _Send(self, msg):
197    # self._Debug(msg.encode("hex"))
198    self.seq += 1
199    self.sock.send(msg)
200
201  def _Recv(self):
202    data = self.sock.recv(self.BUFSIZE)
203    # self._Debug(data.encode("hex"))
204    return data
205
206  def _ExpectDone(self):
207    response = self._Recv()
208    hdr = NLMsgHdr(response)
209    if hdr.type != NLMSG_DONE:
210      raise ValueError("Expected DONE, got type %d" % hdr.type)
211
212  def _ParseAck(self, response):
213    # Find the error code.
214    hdr, data = cstruct.Read(response, NLMsgHdr)
215    if hdr.type == NLMSG_ERROR:
216      error = -NLMsgErr(data).error
217      if error:
218        raise IOError(error, os.strerror(error))
219    else:
220      raise ValueError("Expected ACK, got type %d" % hdr.type)
221
222  def _ExpectAck(self):
223    response = self._Recv()
224    self._ParseAck(response)
225
226  def _SendNlRequest(self, command, data, flags):
227    """Sends a netlink request and expects an ack."""
228    length = len(NLMsgHdr) + len(data)
229    nlmsg = NLMsgHdr((length, command, flags, self.seq, self.pid)).Pack()
230
231    self.MaybeDebugCommand(command, flags, nlmsg + data)
232
233    # Send the message.
234    self._Send(nlmsg + data)
235
236    if flags & NLM_F_ACK:
237      self._ExpectAck()
238
239  def _ParseNLMsg(self, data, msgtype):
240    """Parses a Netlink message into a header and a dictionary of attributes."""
241    nlmsghdr, data = cstruct.Read(data, NLMsgHdr)
242    self._Debug("  %s" % nlmsghdr)
243
244    if nlmsghdr.type == NLMSG_ERROR or nlmsghdr.type == NLMSG_DONE:
245      print("done")
246      return (None, None), data
247
248    nlmsg, data = cstruct.Read(data, msgtype)
249    self._Debug("    %s" % nlmsg)
250
251    # Parse the attributes in the nlmsg.
252    attrlen = nlmsghdr.length - len(nlmsghdr) - len(nlmsg)
253    attributes = self._ParseAttributes(nlmsghdr.type, nlmsg, data[:attrlen], [])
254    data = data[attrlen:]
255    return (nlmsg, attributes), data
256
257  def _GetMsg(self, msgtype):
258    data = self._Recv()
259    if NLMsgHdr(data).type == NLMSG_ERROR:
260      self._ParseAck(data)
261    return self._ParseNLMsg(data, msgtype)[0]
262
263  def _GetMsgList(self, msgtype, data, expect_done):
264    out = []
265    while data:
266      msg, data = self._ParseNLMsg(data, msgtype)
267      if msg is None:
268        break
269      out.append(msg)
270    if expect_done:
271      self._ExpectDone()
272    return out
273
274  def _Dump(self, command, msg, msgtype, attrs=b""):
275    """Sends a dump request and returns a list of decoded messages.
276
277    Args:
278      command: An integer, the command to run (e.g., RTM_NEWADDR).
279      msg: A struct, the request (e.g., a RTMsg). May be None.
280      msgtype: A cstruct.Struct, the data type to parse the dump results as.
281      attrs: A string, the raw bytes of any request attributes to include.
282
283    Returns:
284      A list of (msg, attrs) tuples where msg is of type msgtype and attrs is
285      a dict of attributes.
286    """
287    # Create a netlink dump request containing the msg.
288    flags = NLM_F_DUMP | NLM_F_REQUEST
289    msg = b"" if msg is None else msg.Pack()
290    length = len(NLMsgHdr) + len(msg) + len(attrs)
291    nlmsghdr = NLMsgHdr((length, command, flags, self.seq, self.pid))
292
293    # Send the request.
294    request = nlmsghdr.Pack() + msg + attrs
295    self.MaybeDebugCommand(command, flags, request)
296    self._Send(request)
297
298    # Keep reading netlink messages until we get a NLMSG_DONE.
299    out = []
300    while True:
301      data = self._Recv()
302      response_type = NLMsgHdr(data).type
303      if response_type == NLMSG_DONE:
304        break
305      elif response_type == NLMSG_ERROR:
306        # Likely means that the kernel didn't like our dump request.
307        # Parse the error and throw an exception.
308        self._ParseAck(data)
309      out.extend(self._GetMsgList(msgtype, data, False))
310
311    return out
312