• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31import six
32import struct
33
34import wire_format
35
36
37def _VarintSize(value):
38  """Compute the size of a varint value."""
39  if value <= 0x7f: return 1
40  if value <= 0x3fff: return 2
41  if value <= 0x1fffff: return 3
42  if value <= 0xfffffff: return 4
43  if value <= 0x7ffffffff: return 5
44  if value <= 0x3ffffffffff: return 6
45  if value <= 0x1ffffffffffff: return 7
46  if value <= 0xffffffffffffff: return 8
47  if value <= 0x7fffffffffffffff: return 9
48  return 10
49
50
51def _SignedVarintSize(value):
52  """Compute the size of a signed varint value."""
53  if value < 0: return 10
54  if value <= 0x7f: return 1
55  if value <= 0x3fff: return 2
56  if value <= 0x1fffff: return 3
57  if value <= 0xfffffff: return 4
58  if value <= 0x7ffffffff: return 5
59  if value <= 0x3ffffffffff: return 6
60  if value <= 0x1ffffffffffff: return 7
61  if value <= 0xffffffffffffff: return 8
62  if value <= 0x7fffffffffffffff: return 9
63  return 10
64
65
66def _VarintEncoder():
67  """Return an encoder for a basic varint value (does not include tag)."""
68
69  def EncodeVarint(write, value):
70    bits = value & 0x7f
71    value >>= 7
72    while value:
73      write(six.int2byte(0x80|bits))
74      bits = value & 0x7f
75      value >>= 7
76    return write(six.int2byte(bits))
77
78  return EncodeVarint
79
80
81def _SignedVarintEncoder():
82  """Return an encoder for a basic signed varint value (does not include
83  tag)."""
84
85  def EncodeSignedVarint(write, value):
86    if value < 0:
87      value += (1 << 64)
88    bits = value & 0x7f
89    value >>= 7
90    while value:
91      write(six.int2byte(0x80|bits))
92      bits = value & 0x7f
93      value >>= 7
94    return write(six.int2byte(bits))
95
96  return EncodeSignedVarint
97
98
# Shared encoder instances, created once when the module is imported.
# Each takes (write, value) and emits the varint bytes through write().
_EncodeVarint = _VarintEncoder()
_EncodeSignedVarint = _SignedVarintEncoder()
101
102
def _VarintBytes(value):
  """Return the varint encoding of `value` as a single bytes object.

  Only used at startup, so clarity wins over speed here.
  """

  chunks = []
  collect = chunks.append
  _EncodeVarint(collect, value)
  return b"".join(chunks)
110
111
def TagBytes(field_number, wire_type):
  """Return the encoded tag bytes for a field.  Only called at startup.

  The field number and wire type are combined by wire_format.PackTag() and
  the result is encoded as a varint.
  """

  packed_tag = wire_format.PackTag(field_number, wire_type)
  return _VarintBytes(packed_tag)
116
117
118def _SimpleEncoder(wire_type, encode_value, compute_value_size):
119  """Return a constructor for an encoder for fields of a particular type.
120
121  Args:
122      wire_type:  The field's wire type, for encoding tags.
123      encode_value:  A function which encodes an individual value, e.g.
124        _EncodeVarint().
125      compute_value_size:  A function which computes the size of an individual
126        value, e.g. _VarintSize().
127  """
128
129  def SpecificEncoder(field_number, is_repeated, is_packed):
130    if is_packed:
131      tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
132      local_EncodeVarint = _EncodeVarint
133      def EncodePackedField(write, value):
134        write(tag_bytes)
135        size = 0
136        for element in value:
137          size += compute_value_size(element)
138        local_EncodeVarint(write, size)
139        for element in value:
140          encode_value(write, element)
141      return EncodePackedField
142    elif is_repeated:
143      tag_bytes = TagBytes(field_number, wire_type)
144      def EncodeRepeatedField(write, value):
145        for element in value:
146          write(tag_bytes)
147          encode_value(write, element)
148      return EncodeRepeatedField
149    else:
150      tag_bytes = TagBytes(field_number, wire_type)
151      def EncodeField(write, value):
152        write(tag_bytes)
153        return encode_value(write, value)
154      return EncodeField
155
156  return SpecificEncoder
157
158
159def _FloatingPointEncoder(wire_type, format):
160  """Return a constructor for an encoder for float fields.
161
162  This is like StructPackEncoder, but catches errors that may be due to
163  passing non-finite floating-point values to struct.pack, and makes a
164  second attempt to encode those values.
165
166  Args:
167      wire_type:  The field's wire type, for encoding tags.
168      format:  The format string to pass to struct.pack().
169  """
170
171  value_size = struct.calcsize(format)
172  if value_size == 4:
173    def EncodeNonFiniteOrRaise(write, value):
174      # Remember that the serialized form uses little-endian byte order.
175      if value == _POS_INF:
176        write(b'\x00\x00\x80\x7F')
177      elif value == _NEG_INF:
178        write(b'\x00\x00\x80\xFF')
179      elif value != value:           # NaN
180        write(b'\x00\x00\xC0\x7F')
181      else:
182        raise
183  elif value_size == 8:
184    def EncodeNonFiniteOrRaise(write, value):
185      if value == _POS_INF:
186        write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F')
187      elif value == _NEG_INF:
188        write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF')
189      elif value != value:                         # NaN
190        write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F')
191      else:
192        raise
193  else:
194    raise ValueError('Can\'t encode floating-point values that are '
195                     '%d bytes long (only 4 or 8)' % value_size)
196
197  def SpecificEncoder(field_number, is_repeated, is_packed):
198    local_struct_pack = struct.pack
199    if is_packed:
200      tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
201      local_EncodeVarint = _EncodeVarint
202      def EncodePackedField(write, value):
203        write(tag_bytes)
204        local_EncodeVarint(write, len(value) * value_size)
205        for element in value:
206          # This try/except block is going to be faster than any code that
207          # we could write to check whether element is finite.
208          try:
209            write(local_struct_pack(format, element))
210          except SystemError:
211            EncodeNonFiniteOrRaise(write, element)
212      return EncodePackedField
213    elif is_repeated:
214      tag_bytes = TagBytes(field_number, wire_type)
215      def EncodeRepeatedField(write, value):
216        for element in value:
217          write(tag_bytes)
218          try:
219            write(local_struct_pack(format, element))
220          except SystemError:
221            EncodeNonFiniteOrRaise(write, element)
222      return EncodeRepeatedField
223    else:
224      tag_bytes = TagBytes(field_number, wire_type)
225      def EncodeField(write, value):
226        write(tag_bytes)
227        try:
228          write(local_struct_pack(format, value))
229        except SystemError:
230          EncodeNonFiniteOrRaise(write, value)
231      return EncodeField
232
233  return SpecificEncoder
234
235
# Public encoder constructors.  Each is called as
# (field_number, is_repeated, is_packed) and returns a (write, value) encoder.

# Int32, Int64 and enum fields share the signed varint encoder.
Int32Encoder = Int64Encoder = EnumEncoder = _SimpleEncoder(
    wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize)

# UInt32 and UInt64 fields share the plain varint encoder.
UInt32Encoder = UInt64Encoder = _SimpleEncoder(
    wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize)

# '<f' is a 4-byte little-endian float for struct.pack().
FloatEncoder    = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED32, '<f')

# '<d' is an 8-byte little-endian double for struct.pack().
DoubleEncoder   = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d')
245
246
def BoolEncoder(field_number, is_repeated, is_packed):
  """Return a (write, value) encoder for a boolean field.

  Each boolean is written as one byte: b'\\x01' when truthy, b'\\x00'
  otherwise.  Packed fields are written as a length-delimited record whose
  length is the element count.  Repeated fields repeat the varint tag before
  every element.
  """

  zero = b'\x00'
  one = b'\x01'

  if is_packed:
    packed_tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
    encode_length = _EncodeVarint

    def EncodePackedField(write, value):
      write(packed_tag)
      # One byte per element, so the payload length is the element count.
      encode_length(write, len(value))
      for flag in value:
        write(one if flag else zero)

    return EncodePackedField

  varint_tag = TagBytes(field_number, wire_format.WIRETYPE_VARINT)

  if is_repeated:
    def EncodeRepeatedField(write, value):
      for flag in value:
        write(varint_tag)
        write(one if flag else zero)

    return EncodeRepeatedField

  def EncodeField(write, value):
    write(varint_tag)
    return write(one if value else zero)

  return EncodeField
282
283
def StringEncoder(field_number, is_repeated, is_packed):
  """Return a (write, value) encoder for a string field.

  Each string is encoded as UTF-8 and written as the length-delimited tag,
  the byte length as a varint, then the encoded bytes.  String fields cannot
  be packed.
  """

  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  encode_length = _EncodeVarint
  assert not is_packed

  def EncodeField(write, value):
    # Encode first so a failing value leaves nothing half-written.
    utf8 = value.encode('utf-8')
    write(tag)
    encode_length(write, len(utf8))
    return write(utf8)

  if not is_repeated:
    return EncodeField

  def EncodeRepeatedField(write, value):
    for element in value:
      EncodeField(write, element)

  return EncodeRepeatedField
306
307