• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2014 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""A simple module for declaring C-like structures.
16
17Example usage:
18
19>>> # Declare a struct type by specifying name, field formats and field names.
20... # Field formats are the same as those used in the struct module, except:
21... # - S: Nested Struct.
... # - A: NULL-padded ASCII string. Packed like s, but trailing NULL bytes
... #      are stripped when the struct is printed.
24... import cstruct
25>>> NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
26>>>
27>>>
28>>> # Create instances from a tuple of values, raw bytes, zero-initialized, or
29>>> # using keywords.
30... n1 = NLMsgHdr((44, 32, 0x2, 0, 491))
31>>> print n1
32NLMsgHdr(length=44, type=32, flags=2, seq=0, pid=491)
33>>>
34>>> n2 = NLMsgHdr("\x2c\x00\x00\x00\x21\x00\x02\x00"
35...               "\x00\x00\x00\x00\xfe\x01\x00\x00" + "junk at end")
36>>> print n2
37NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510)
38>>>
>>> n3 = NLMsgHdr()  # Zero-initialized
>>> print n3
NLMsgHdr(length=0, type=0, flags=0, seq=0, pid=0)
>>>
>>> n4 = NLMsgHdr(length=44, type=33)  # Other fields zero-initialized
>>> print n4
NLMsgHdr(length=44, type=33, flags=0, seq=0, pid=0)
46>>>
47>>> # Serialize to raw bytes.
48... print n1.Pack().encode("hex")
492c0000002000020000000000eb010000
50>>>
51>>> # Parse the beginning of a byte stream as a struct, and return the struct
52... # and the remainder of the stream for further reading.
53... data = ("\x2c\x00\x00\x00\x21\x00\x02\x00"
54...         "\x00\x00\x00\x00\xfe\x01\x00\x00"
55...         "more data")
56>>> cstruct.Read(data, NLMsgHdr)
57(NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510), 'more data')
58>>>
59>>> # Structs can contain one or more nested structs. The nested struct types
60... # are specified in a list as an optional last argument. Nested structs may
61... # contain nested structs.
62... S = cstruct.Struct("S", "=BI", "byte1 int2")
63>>> N = cstruct.Struct("N", "!BSiS", "byte1 s2 int3 s2", [S, S])
64>>> NN = cstruct.Struct("NN", "SHS", "s1 word2 n3", [S, N])
65>>> nn = NN((S((1, 25000)), -29876, N((55, S((5, 6)), 1111, S((7, 8))))))
66>>> nn.n3.s2.int2 = 5
67>>>
68"""
69
70import ctypes
71import string
72import struct
73
74
def CalcSize(fmt):
  """Returns the size in bytes of the packed form of a cstruct format.

  An "A" (NULL-padded ASCII string) item is laid out exactly like "s", so it
  is rewritten to "s" before the struct module sizes the format.
  """
  struct_fmt = fmt.replace("A", "s") if "A" in fmt else fmt
  return struct.calcsize(struct_fmt)
79
def CalcNumElements(fmt):
  """Returns the number of values a cstruct format string describes.

  Each "S" (nested struct) counts as one value. Each "A" (NULL-padded ASCII
  string) item counts like "s", i.e. as one value regardless of its length.
  fmt may be a prefix of a longer format, possibly ending in the repeat count
  of a code that is not included.

  Args:
    fmt: a cstruct format string.

  Returns:
    The number of values fmt unpacks to.
  """
  # Nested structs are counted directly; their "S" codes are not struct codes.
  numstructs = fmt.count("S")
  # "A" is not a struct code either; it unpacks exactly like "s".
  fmt = fmt.replace("S", "").replace("A", "s")
  # A trailing repeat count belongs to a code outside fmt and describes no
  # values here (newer Python versions also reject it as a struct format).
  fmt = fmt.rstrip("0123456789")
  size = struct.calcsize(fmt)
  elements = struct.unpack(fmt, b"\x00" * size)
  return len(elements) + numstructs
87
88
def _NumValues(code, count):
  """Returns how many values one struct-module format item unpacks to.

  Args:
    code: a single struct format character, e.g. "H", "s" or "x".
    count: the repeat-count digits written before code, or "" if none.

  Returns:
    The number of values struct.unpack produces for this item.
  """
  if code.isspace() or code in "@=<>!x":
    # Whitespace, byte order/alignment markers and pad bytes hold no value.
    return 0
  if code in "sp":
    # For strings the count is a length in bytes; the item is a single value.
    return 1
  return int(count or 1)


def _ExpandFormat(name, fmt, substructs):
  """Translates a cstruct format string into a struct module format string.

  Each "S" becomes "<N>s", where N is the packed length of the next nested
  struct type in substructs, and each "A" becomes "s". All other characters
  are copied unchanged.

  Args:
    name: the struct's name, used in error messages.
    fmt: the cstruct format string.
    substructs: nested struct types, indexed in the order their "S" codes
      appear in fmt.

  Returns:
    A (struct_format, nested, asciiz) tuple: the expanded format, a dict
    mapping value indices to nested struct types, and the set of value
    indices that hold "A" strings.

  Raises:
    TypeError: fmt has more "S" codes than substructs provides.
  """
  struct_format = ""
  nested = {}
  asciiz = set()
  num_values = 0      # Values unpacked by the items before the current code.
  num_substructs = 0  # "S" codes consumed so far.
  count = ""          # Repeat-count digits preceding the current code.
  for code in fmt:
    if code.isdigit():
      count += code
      struct_format += code
      continue
    if code == "S":
      try:
        substruct = substructs[num_substructs]
      except (IndexError, KeyError):
        raise TypeError("%s: format has more S fields than the %d nested "
                        "struct types given" % (name, len(substructs)))
      num_substructs += 1
      nested[num_values] = substruct
      struct_format += "%ds" % len(substruct)
      num_values += 1
    elif code == "A":
      asciiz.add(num_values)
      struct_format += "s"
      num_values += 1
    else:
      struct_format += code
      num_values += _NumValues(code, count)
    count = ""
  return struct_format, nested, asciiz


def _ValueOffsets(struct_format):
  """Returns the byte offset of each value that struct_format unpacks to.

  Pad bytes ("x") produce no value, repeated items ("4H") produce one offset
  per element, and in native mode the alignment padding that the struct
  module inserts before an item is included in that item's offset.
  """
  prefix = ""
  if struct_format[:1] in ("@", "=", "<", ">", "!"):
    prefix = struct_format[0]
  offsets = []
  count = ""
  for pos in range(len(prefix), len(struct_format)):
    code = struct_format[pos]
    if code.isdigit():
      count += code
      continue
    item = count + code
    num_values = _NumValues(code, count)
    count = ""
    if not num_values:
      continue
    # The struct module adds no trailing padding, so the size of the format up
    # to and including this item is exactly where the item ends. Subtracting
    # the item's own size gives its start, after any alignment padding.
    end = struct.calcsize(struct_format[:pos + 1])
    start = end - struct.calcsize(prefix + item)
    step = struct.calcsize(prefix + code)
    offsets.extend(start + i * step for i in range(num_values))
  return offsets


def Struct(name, fmt, fieldnames, substructs=None):
  """Function that returns struct classes.

  Args:
    name: the struct's name, used by str() and in error messages.
    fmt: the field format: struct module syntax, plus "S" for a nested struct
      and "A" for a NULL-padded ASCII string (packed like "s").
    fieldnames: space-separated field names, one per value in fmt.
    substructs: the nested struct types, in the order their "S" codes appear
      in fmt. Optional when fmt has no "S".

  Returns:
    A new class whose instances hold one value per field; see the module
    docstring for usage.

  Raises:
    TypeError: fmt has more "S" codes than substructs provides.
  """
  if substructs is None:
    substructs = []

  # Everything derived from the format is computed here rather than in the
  # class body: temporaries in a class body become class attributes, and a
  # class attribute shadows a field of the same name (e.g. "i" or "index")
  # because __getattr__ is only consulted when normal lookup fails.
  field_list = fieldnames.split(" ")
  struct_format, nested, asciiz = _ExpandFormat(name, fmt, substructs)
  offsets = _ValueOffsets(struct_format)

  class Meta(type):

    def __new__(mcs, unused_name, bases, namespace):
      # Make the class object have the name that's passed in. This has to
      # happen in __new__: type.__init__ ignores its name argument.
      return type.__new__(mcs, str(namespace["_name"]), bases, namespace)

    def __len__(cls):
      return cls._length

  class CStruct(object):
    """Class representing a C-like structure."""

    # Python 2 metaclass hook.
    __metaclass__ = Meta

    # Name of the struct.
    _name = name
    # List of field names.
    _fieldnames = field_list
    # Format for the struct module, with "S" and "A" codes expanded.
    _format = struct_format
    # Dict mapping field indices to nested struct classes.
    _nested = nested
    # Set of indices of string fields that are ASCII strings.
    _asciiz = asciiz
    # Length in bytes of the packed representation.
    _length = struct.calcsize(struct_format)
    # A dictionary that maps field names to their offsets in the struct.
    _offsets = dict(zip(field_list, offsets))

    def _SetValues(self, values):
      # Replace self._values with the given list. We can't do direct assignment
      # because of the __setattr__ overload on this class.
      super(CStruct, self).__setattr__("_values", list(values))

    def _Parse(self, data):
      data = data[:self._length]
      values = list(struct.unpack(self._format, data))
      # Nested struct fields unpack as raw strings; parse them further.
      for index, value in enumerate(values):
        if isinstance(value, str) and index in self._nested:
          values[index] = self._nested[index](value)
      self._SetValues(values)

    def __init__(self, tuple_or_bytes=None, **kwargs):
      """Construct an instance of this Struct.

      1. With no args, the whole struct is zero-initialized.
      2. With keyword args, the matching fields are populated; rest are zeroed.
      3. With one tuple as the arg, the fields are assigned based on position.
      4. With one string arg, the Struct is parsed from bytes.
      """
      if tuple_or_bytes and kwargs:
        raise TypeError(
            "%s: cannot specify both a tuple and keyword args" % self._name)

      if tuple_or_bytes is None:
        # Default construct from null bytes.
        self._Parse("\x00" * len(self))
        # If any keywords were supplied, set those fields.
        for k, v in kwargs.items():
          setattr(self, k, v)
      elif isinstance(tuple_or_bytes, str):
        # Initializing from a string.
        if len(tuple_or_bytes) < self._length:
          raise TypeError("%s requires string of length %d, got %d" %
                          (self._name, self._length, len(tuple_or_bytes)))
        self._Parse(tuple_or_bytes)
      else:
        # Initializing from a tuple.
        if len(tuple_or_bytes) != len(self._fieldnames):
          raise TypeError("%s has exactly %d fieldnames (%d given)" %
                          (self._name, len(self._fieldnames),
                           len(tuple_or_bytes)))
        self._SetValues(tuple_or_bytes)

    def _FieldIndex(self, attr):
      try:
        return self._fieldnames.index(attr)
      except ValueError:
        raise AttributeError("'%s' has no attribute '%s'" %
                             (self._name, attr))

    def __getattr__(self, name):
      # Only called when normal lookup fails. If _values itself is missing
      # (e.g. on an instance created without running __init__), fail with
      # AttributeError instead of recursing through self._values forever.
      if name == "_values":
        raise AttributeError(name)
      return self._values[self._FieldIndex(name)]

    def __setattr__(self, name, value):
      # TODO: check value type against self._format and throw here, or else
      # callers get an unhelpful exception when they call Pack().
      self._values[self._FieldIndex(name)] = value

    def offset(self, name):
      """Returns the byte offset of the named top-level field."""
      if "." in name:
        raise NotImplementedError("offset() on nested field")
      return self._offsets[name]

    @classmethod
    def __len__(cls):
      return cls._length

    def __ne__(self, other):
      return not self.__eq__(other)

    def __eq__(self, other):
      return (isinstance(other, self.__class__) and
              self._name == other._name and
              self._fieldnames == other._fieldnames and
              self._values == other._values)

    @staticmethod
    def _MaybePackStruct(value):
      # Every Struct() call creates its own Meta and CStruct, so a nested struct
      # is recognized by the __metaclass__ attribute all generated classes
      # carry, not by comparing against this call's classes.
      if hasattr(value, "__metaclass__"):
        return value.Pack()
      return value

    def Pack(self):
      """Returns the packed byte representation of this struct."""
      values = [self._MaybePackStruct(v) for v in self._values]
      return struct.pack(self._format, *values)

    def __str__(self):
      def FieldDesc(index, name, value):
        if isinstance(value, str):
          if index in self._asciiz:
            value = value.rstrip("\x00")
          elif any(c not in string.printable for c in value):
            value = value.encode("hex")
        return "%s=%s" % (name, value)

      descriptions = [
          FieldDesc(i, n, v) for i, (n, v) in
          enumerate(zip(self._fieldnames, self._values))]

      return "%s(%s)" % (self._name, ", ".join(descriptions))

    def __repr__(self):
      return str(self)

    def CPointer(self):
      """Returns a C pointer to the serialized structure."""
      buf = ctypes.create_string_buffer(self.Pack())
      # Store the C buffer in the object so it doesn't get garbage collected.
      super(CStruct, self).__setattr__("_buffer", buf)
      return ctypes.addressof(self._buffer)

  return CStruct
265
266
def Read(data, struct_type):
  """Parses one struct from the front of a byte string.

  Args:
    data: the bytes to read from.
    struct_type: a struct class, as returned by Struct().

  Returns:
    A (struct, rest) tuple: struct_type constructed from data, and the part
    of data that follows the first len(struct_type) bytes.
  """
  consumed = len(struct_type)
  parsed = struct_type(data)
  rest = data[consumed:]
  return parsed, rest
270