• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2014 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""A simple module for declaring C-like structures.
16
17Example usage:
18
19>>> # Declare a struct type by specifying name, field formats and field names.
20... # Field formats are the same as those used in the struct module, except:
21... # - S: Nested Struct.
22... # - A: NULL-padded ASCII string. Like s, but printing ignores contiguous
23... #      trailing NULL blocks at the end.
24... import cstruct
25>>> NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
26>>>
27>>>
28>>> # Create instances from a tuple of values, raw bytes, zero-initialized, or
29>>> # using keywords.
30... n1 = NLMsgHdr((44, 32, 0x2, 0, 491))
31>>> print n1
32NLMsgHdr(length=44, type=32, flags=2, seq=0, pid=491)
33>>>
34>>> n2 = NLMsgHdr("\x2c\x00\x00\x00\x21\x00\x02\x00"
35...               "\x00\x00\x00\x00\xfe\x01\x00\x00" + "junk at end")
36>>> print n2
37NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510)
38>>>
>>> n3 = NLMsgHdr() # Zero-initialized
40>>> print n3
41NLMsgHdr(length=0, type=0, flags=0, seq=0, pid=0)
42>>>
>>> n4 = NLMsgHdr(length=44, type=33) # Other fields zero-initialized
44>>> print n4
45NLMsgHdr(length=44, type=33, flags=0, seq=0, pid=0)
46>>>
47>>> # Serialize to raw bytes.
48... print n1.Pack().encode("hex")
492c0000002000020000000000eb010000
50>>>
51>>> # Parse the beginning of a byte stream as a struct, and return the struct
52... # and the remainder of the stream for further reading.
53... data = ("\x2c\x00\x00\x00\x21\x00\x02\x00"
54...         "\x00\x00\x00\x00\xfe\x01\x00\x00"
55...         "more data")
56>>> cstruct.Read(data, NLMsgHdr)
57(NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510), 'more data')
58>>>
59>>> # Structs can contain one or more nested structs. The nested struct types
60... # are specified in a list as an optional last argument. Nested structs may
61... # contain nested structs.
62... S = cstruct.Struct("S", "=BI", "byte1 int2")
>>> N = cstruct.Struct("N", "!BSiS", "byte1 s2 int3 s4", [S, S])
64>>> NN = cstruct.Struct("NN", "SHS", "s1 word2 n3", [S, N])
65>>> nn = NN((S((1, 25000)), -29876, N((55, S((5, 6)), 1111, S((7, 8))))))
66>>> nn.n3.s2.int2 = 5
67>>>
68"""
69
70import ctypes
71import string
72import struct
73
74
def CalcSize(fmt):
  """Returns the packed size in bytes of a cstruct format string.

  "A" (NULL-padded ASCII string) has exactly the same layout as "s", so it is
  translated to "s" before the struct module computes the size.
  """
  return struct.calcsize(fmt.replace("A", "s"))
79
def CalcNumElements(fmt):
  """Returns how many fields a cstruct format string describes.

  Each "S" (nested struct) counts as one field. "A" is treated like "s", because
  struct.unpack does not understand "A". A trailing repeat count with no format
  character after it (such as the "4" in the prefix "=4" of "=4A") belongs to a
  character outside fmt, so it describes no field here and is ignored.
  """
  numstructs = fmt.count("S")
  fmt = fmt.replace("S", "").replace("A", "s").rstrip("0123456789")
  size = struct.calcsize(fmt)
  return len(struct.unpack(fmt, b"\x00" * size)) + numstructs
87
88
def _CalcOffsets(fmt):
  """Returns the byte offset of every element that struct.unpack(fmt) yields.

  Args:
    fmt: A plain struct-module format string (no "S" or "A" characters).

  Returns:
    A list with one offset per unpacked element, in order. Native alignment
    padding, "x" pad bytes and repeat counts such as "2H" are accounted for.
  """
  prefix = ""
  if fmt[:1] in ("@", "=", "<", ">", "!"):
    prefix, fmt = fmt[0], fmt[1:]

  offsets = []
  consumed = prefix  # Normalized format of every item before the current one.
  count = ""
  for char in fmt:
    if char.isdigit():
      count += char
      continue
    if char.isspace():
      continue
    repeat = int(count) if count else 1
    count = ""
    if char in "sp":
      # A string is a single element regardless of its length.
      item = "%d%s" % (repeat, char)
      offsets.append(struct.calcsize(consumed + item) -
                     struct.calcsize(prefix + item))
    elif char != "x":
      # struct never adds padding after the last item, so an item's offset is
      # the size of the format up to and including it minus its own size.
      itemsize = struct.calcsize(prefix + char)
      for n in range(1, repeat + 1):
        offsets.append(struct.calcsize(consumed + "%d%s" % (n, char)) -
                       itemsize)
    # "x" pad bytes take up space but produce no element.
    consumed += "%d%s" % (repeat, char)
  return offsets


def Struct(name, fmt, fieldnames, substructs=None):
  """Function that returns struct classes.

  Args:
    name: Name of the struct, used by str()/repr() and as the class name.
    fmt: A struct-module format string. Two extra characters are accepted:
      "S" for a nested struct (taken in order from substructs) and "A" for a
      NULL-padded ASCII string.
    fieldnames: Space-separated field names, one per field of fmt.
    substructs: Sequence of struct classes, one per "S" in fmt, in order.

  Returns:
    A new class whose instances hold one value per field.

  Raises:
    ValueError: if fmt has more "S" characters than substructs has entries, or
      the number of field names does not match the number of fields.
  """
  if substructs is None:
    substructs = []

  field_list = fieldnames.split(" ")

  # Parse fmt into a plain struct-module format, converting any S format
  # characters to "XXs", where XX is the length of the struct type's packed
  # representation, and any A format characters to "s". This is done here
  # rather than in the class body so that loop temporaries do not become class
  # attributes, which would shadow fields of the same name.
  plain_format = ""
  nested = {}  # Maps field indices to nested struct classes.
  asciiz = set()  # Indices of fields that are NULL-padded ASCII strings.
  numstructs = 0
  for i, char in enumerate(fmt):
    if char == "S":
      # Nested struct. Record the index in our struct it should go into.
      if numstructs >= len(substructs):
        raise ValueError("Invalid cstruct: \"%s\" has more S fields than the "
                         "%d nested struct types given"
                         % (fmt, len(substructs)))
      index = CalcNumElements(fmt[:i])
      nested[index] = substructs[numstructs]
      numstructs += 1
      plain_format += "%ds" % len(nested[index])
    elif char == "A":
      # Null-terminated ASCII string.
      asciiz.add(CalcNumElements(fmt[:i]))
      plain_format += "s"
    else:
      # Standard struct format character.
      plain_format += char

  length = struct.calcsize(plain_format)

  # Check that the number of field names matches the number of fields.
  numfields = len(struct.unpack(plain_format, b"\x00" * length))
  if len(field_list) != numfields:
    raise ValueError("Invalid cstruct: \"%s\" has %d elements, \"%s\" has %d."
                     % (fmt, numfields, fieldnames, len(field_list)))

  offsets = dict(zip(field_list, _CalcOffsets(plain_format)))

  class Meta(type):
    """Per-struct metaclass that names the class and gives it a len()."""

    def __new__(mcs, unused_name, bases, namespace):
      # Make the class object have the name that's passed in. This must be
      # done in __new__: type.__init__ ignores its name argument.
      return type.__new__(mcs, str(namespace["_name"]), bases, namespace)

    def __len__(cls):
      return cls._length

  class CStruct(object):
    """Class representing a C-like structure."""

    __metaclass__ = Meta

    # Name of the struct.
    _name = name
    # List of field names.
    _fieldnames = field_list
    # Dict mapping field indices to nested struct classes.
    _nested = nested
    # Set of indices of fields that are NULL-padded ASCII strings.
    _asciiz = asciiz
    # Plain struct-module format used to pack and unpack the fields.
    _format = plain_format
    # Length in bytes of the packed struct.
    _length = length
    # A dictionary that maps field names to their offsets in the struct.
    _offsets = offsets

    def _SetValues(self, values):
      # Replace self._values with the given list. We can't do direct assignment
      # because of the __setattr__ overload on this class.
      super(CStruct, self).__setattr__("_values", list(values))

    def _Parse(self, data):
      """Sets the field values by unpacking the first _length bytes of data."""
      data = data[:self._length]
      values = list(struct.unpack(self._format, data))
      for index, value in enumerate(values):
        # Nested structs unpack as strings; turn them into struct instances.
        if isinstance(value, str) and index in self._nested:
          values[index] = self._nested[index](value)
      self._SetValues(values)

    def __init__(self, tuple_or_bytes=None, **kwargs):
      """Construct an instance of this Struct.

      1. With no args, the whole struct is zero-initialized.
      2. With keyword args, the matching fields are populated; rest are zeroed.
      3. With one tuple as the arg, the fields are assigned based on position.
      4. With one string arg, the Struct is parsed from bytes.

      Raises:
        TypeError: if both a positional argument and keyword args are given,
          if a string is shorter than the struct, or if a tuple has the wrong
          number of elements.
        AttributeError: if a keyword argument does not name a field.
      """
      if tuple_or_bytes is not None and kwargs:
        raise TypeError(
            "%s: cannot specify both a tuple and keyword args" % self._name)

      if tuple_or_bytes is None:
        # Default construct from null bytes.
        self._Parse(b"\x00" * len(self))
        # If any keywords were supplied, set those fields.
        for k, v in kwargs.items():
          setattr(self, k, v)
      elif isinstance(tuple_or_bytes, str):
        # Initializing from a string.
        if len(tuple_or_bytes) < self._length:
          raise TypeError("%s requires string of length %d, got %d" %
                          (self._name, self._length, len(tuple_or_bytes)))
        self._Parse(tuple_or_bytes)
      else:
        # Initializing from a tuple.
        if len(tuple_or_bytes) != len(self._fieldnames):
          raise TypeError("%s has exactly %d fieldnames (%d given)" %
                          (self._name, len(self._fieldnames),
                           len(tuple_or_bytes)))
        self._SetValues(tuple_or_bytes)

    def _FieldIndex(self, attr):
      """Returns the position of field attr, or raises AttributeError."""
      try:
        return self._fieldnames.index(attr)
      except ValueError:
        raise AttributeError("'%s' has no attribute '%s'" %
                             (self._name, attr))

    def __getattr__(self, name):
      # Only called when normal attribute lookup fails, i.e. for field names.
      return self._values[self._FieldIndex(name)]

    def __setattr__(self, name, value):
      # TODO: check value type against self._format and throw here, or else
      # callers get an unhelpful exception when they call Pack().
      self._values[self._FieldIndex(name)] = value

    def offset(self, name):
      """Returns the byte offset of field name within the packed struct.

      Raises:
        NotImplementedError: for dotted (nested) field names.
        KeyError: if name is not a field of this struct.
      """
      if "." in name:
        raise NotImplementedError("offset() on nested field")
      return self._offsets[name]

    # Lets len() work on instances; Meta.__len__ covers the class itself.
    @classmethod
    def __len__(cls):
      return cls._length

    def __ne__(self, other):
      return not self.__eq__(other)

    def __eq__(self, other):
      return (isinstance(other, self.__class__) and
              self._name == other._name and
              self._fieldnames == other._fieldnames and
              self._values == other._values)

    @staticmethod
    def _MaybePackStruct(value):
      # Every Struct() call creates its own Meta, so a nested struct's class
      # cannot be recognized by comparing metaclasses. Instead, any value with
      # a __metaclass__ attribute is treated as a struct and packed.
      if hasattr(value, "__metaclass__"):
        return value.Pack()
      else:
        return value

    def Pack(self):
      """Returns the serialized bytes of this struct, packing nested structs."""
      values = [self._MaybePackStruct(v) for v in self._values]
      return struct.pack(self._format, *values)

    def __str__(self):
      def FieldDesc(index, name, value):
        if isinstance(value, str):
          if index in self._asciiz:
            value = value.rstrip("\x00")
          elif any(c not in string.printable for c in value):
            value = value.encode("hex")
        return "%s=%s" % (name, value)

      descriptions = [
          FieldDesc(i, n, v) for i, (n, v) in
          enumerate(zip(self._fieldnames, self._values))]

      return "%s(%s)" % (self._name, ", ".join(descriptions))

    def __repr__(self):
      return str(self)

    def CPointer(self):
      """Returns a C pointer to the serialized structure."""
      buf = ctypes.create_string_buffer(self.Pack())
      # Store the C buffer in the object so it doesn't get garbage collected.
      super(CStruct, self).__setattr__("_buffer", buf)
      return ctypes.addressof(self._buffer)

  return CStruct
271
272
def Read(data, struct_type):
  """Reads one struct_type from the start of data.

  Args:
    data: The byte string to read from.
    struct_type: A struct class; len(struct_type) is its packed length.

  Returns:
    A tuple (struct_type(data), rest), where rest is everything in data after
    the first len(struct_type) bytes.
  """
  remainder = data[len(struct_type):]
  return struct_type(data), remainder
276